
Question about custom dataset finetuning #45

Closed
mystorm16 opened this issue Jul 3, 2024 · 4 comments

@mystorm16

Hi, thanks for the great work.

When I fine-tune with acid.ckpt or re10k.ckpt:

python3 -m src.main +experiment=custom checkpointing.load=checkpoints/re10k.ckpt mode=train data_loader.train.batch_size=4

an error occurs:

KeyError: 'Trying to restore optimizer state but checkpoint contains only the model. This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`.'

Is my finetune command incorrect?

@Langwenchong

I am also facing this issue. It seems the problem arises because, when PL restores the model from the checkpoint file, it also expects to read the corresponding optimizer state (such as the learning rate) saved when training was interrupted. However, the checkpoint files provided by the author have these states filtered out, leaving only the model parameters, which causes the error. I wonder if the author could additionally provide a checkpoint file that contains all the state required for finetuning🫡?
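
One can see this directly by inspecting the top-level keys of the released checkpoint; a minimal sketch, assuming checkpoints/re10k.ckpt has been downloaded:

import torch

# Load the released checkpoint on CPU and list its top-level keys.
# A full Lightning checkpoint also stores 'optimizer_states', 'lr_schedulers',
# etc.; a weights-only checkpoint lacks those entries, which is what triggers
# the KeyError when ckpt_path is passed to trainer.fit.
ckpt = torch.load("checkpoints/re10k.ckpt", map_location="cpu")
print(sorted(ckpt.keys()))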

@donydchen
Owner

Hi @mystorm16 and @Langwenchong, thanks for your interest in our work.

To fine-tune from the released weights, you can initialize the model from the existing checkpoint and skip passing the checkpoint path to the fit function. Below is a workaround for your reference.

Change the model initialization from

mvsplat/src/main.py

Lines 123 to 132 in 378ff81

model_wrapper = ModelWrapper(
    cfg.optimizer,
    cfg.test,
    cfg.train,
    encoder,
    encoder_visualizer,
    get_decoder(cfg.model.decoder, cfg.dataset),
    get_losses(cfg.loss),
    step_tracker
)
to

# Build the model wrapper and load its weights from the released checkpoint;
# the trainer state (optimizer, schedulers) is intentionally not restored here.
model_kwargs = {
    "optimizer_cfg": cfg.optimizer,
    "test_cfg": cfg.test,
    "train_cfg": cfg.train,
    "encoder": encoder,
    "encoder_visualizer": encoder_visualizer,
    "decoder": get_decoder(cfg.model.decoder, cfg.dataset),
    "losses": get_losses(cfg.loss),
    "step_tracker": step_tracker,
}
model_wrapper = ModelWrapper.load_from_checkpoint(
    checkpoint_path, **model_kwargs, strict=True, map_location="cpu",
)

Then, change the fit function from

trainer.fit(model_wrapper, datamodule=data_module, ckpt_path=checkpoint_path)
to

trainer.fit(model_wrapper, datamodule=data_module, ckpt_path=None)

You can confirm the setup by checking the first validation rendering logged at step 0, which should already show a good visual result. I will find time in the coming weeks to update the code to support fine-tuning. Feel free to let me know if you have any other questions or suggestions.

@mystorm16
Author

mystorm16 commented Jul 4, 2024

Thanks for the quick reply!
I made the following modification and got a good visual result. Is this equivalent to your approach?

    # Manually copy the encoder weights from the released checkpoint;
    # keys in the checkpoint are prefixed with 'encoder.' by the ModelWrapper.
    model_state_dict = encoder.state_dict()
    checkpoint = torch.load('checkpoints/acid.ckpt')
    checkpoint_state_dict = checkpoint['state_dict']
    for key in model_state_dict:
        if 'encoder.' + key in checkpoint_state_dict:
            if model_state_dict[key].shape == checkpoint_state_dict['encoder.' + key].shape:
                model_state_dict[key].copy_(checkpoint_state_dict['encoder.' + key])
            else:
                print(f"Shape mismatch for parameter {key}. Skipping...")
    encoder.load_state_dict(model_state_dict)

@donydchen donydchen self-assigned this Jul 5, 2024
@donydchen
Owner

Hi @mystorm16, I think your solution does the same thing as the one I provided above since the decoder actually has no trainable parameters. Cheers.
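
A quick way to double-check that the decoder contributes no trainable parameters is to count them; a minimal sketch, assuming the object returned by get_decoder(cfg.model.decoder, cfg.dataset) is a standard torch.nn.Module:

decoder = get_decoder(cfg.model.decoder, cfg.dataset)
# Sum the sizes of all parameters that would receive gradients during training.
num_trainable = sum(p.numel() for p in decoder.parameters() if p.requires_grad)
print(f"Trainable decoder parameters: {num_trainable}")  # expected to be 0 here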

donydchen added a commit that referenced this issue Jul 18, 2024