
RunningNLLMetric returns infinite when running optimize_prior_precision #194

Closed
MJordahn opened this issue Jun 11, 2024 · 3 comments · Fixed by #199
MJordahn commented Jun 11, 2024

Discussion started in #193.

After running fit (with the AsdlGGN backend) in FullLLLaplace, I can make predictions with the Laplace module (and see that my predictive function changes compared to my MAP model, but in a reasonable way). When I then run optimize_prior_precision with the gridsearch method, the model's predictive function breaks.

When I use the progress_bar option in the GitHub version of the repo, I get the following (just a snippet):

  0%|          | 0/21 [00:00<?, ?it/s]
[Grid search | prior_prec: 1.000e-04, loss: inf]:   0%|          | 0/21 [00:22<?, ?it/s]
[Grid search | prior_prec: 1.000e-04, loss: inf]:   5%|▍         | 1/21 [00:22<07:39, 22.97s/it]
[Grid search | prior_prec: 2.512e-04, loss: inf]:   5%|▍         | 1/21 [00:44<07:39, 22.97s/it]
[Grid search | prior_prec: 2.512e-04, loss: inf]:  10%|▉         | 2/21 [00:44<07:00, 22.12s/it]
[Grid search | prior_prec: 6.310e-04, loss: inf]:  10%|▉         | 2/21 [01:07<07:00, 22.12s/it]
[Grid search | prior_prec: 6.310e-04, loss: inf]:  14%|█▍        | 3/21 [01:07<06:43, 22.43s/it]
[Grid search | prior_prec: 1.585e-03, loss: inf]:  14%|█▍        | 3/21 [01:30<06:43, 22.43s/it]
[Grid search | prior_prec: 1.585e-03, loss: inf]:  19%|█▉        | 4/21 [01:30<06:23, 22.57s/it]
[Grid search | prior_prec: 3.981e-03, loss: inf]:  19%|█▉        | 4/21 [01:52<06:23, 22.57s/it]
[Grid search | prior_prec: 3.981e-03, loss: inf]:  24%|██▍       | 5/21 [01:52<06:02, 22.64s/it]
[Grid search | prior_prec: 1.000e-02, loss: inf]:  24%|██▍       | 5/21 [02:15<06:02, 22.64s/it]
[Grid search | prior_prec: 1.000e-02, loss: inf]:  29%|██▊       | 6/21 [02:15<05:40, 22.70s/it]
[Grid search | prior_prec: 2.512e-02, loss: inf]:  29%|██▊       | 6/21 [02:38<05:40, 22.70s/it]
[Grid search | prior_prec: 2.512e-02, loss: inf]:  33%|███▎      | 7/21 [02:38<05:18, 22.73s/it]
[Grid search | prior_prec: 6.310e-02, loss: inf]:  33%|███▎      | 7/21 [03:01<05:18, 22.73s/it]
[Grid search | prior_prec: 6.310e-02, loss: inf]:  38%|███▊      | 8/21 [03:01<04:55, 22.75s/it]
[Grid search | prior_prec: 1.585e-01, loss: inf]:  38%|███▊      | 8/21 [03:24<04:55, 22.75s/it]
[Grid search | prior_prec: 1.585e-01, loss: inf]:  43%|████▎     | 9/21 [03:24<04:33, 22.77s/it]

When I compute the NLL using PyTorch's NLLLoss on the model that has only been fitted, without the prior precision being optimized, the outputs are reasonable.

I have just seen issue #157, which I suppose is related to this problem.
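
For reference, a minimal sketch of the setup described above, assuming `model`, `train_loader`, `val_loader`, and `x_test` are the user's own network and data; argument names follow the GitHub version of laplace-torch current at the time:

```python
import torch
from laplace import Laplace
from laplace.curvature import AsdlGGN

# Last-layer Laplace with full Hessian structure (FullLLLaplace) and the
# AsdlGGN curvature backend, as described in the report above.
la = Laplace(
    model,
    "classification",
    subset_of_weights="last_layer",
    hessian_structure="full",
    backend=AsdlGGN,
)
la.fit(train_loader)

# Predictions right after fit() look reasonable ...
probs = la(x_test)

# ... but every grid-search candidate reports loss: inf.
la.optimize_prior_precision(
    method="gridsearch",
    val_loader=val_loader,
    progress_bar=True,
)
```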

MJordahn commented Jun 11, 2024

I have found a solution for my case. The problem stemmed from the fact that the passed loss is never moved to the same device that the rest of the Laplace class is on. It was failing silently because of the try statement, as mentioned in #157.

My solution was simply to move the loss function to the Laplace module's own device (_device) inside optimize_prior_precision_base of BaseLaplace (if loss is None, I instantiate RunningNLLMetric directly on _device; otherwise I move it there).

I am not sure whether it is a pretty solution, but it works for me now (at least I am no longer getting inf values when running optimize_prior_precision).
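
A rough sketch of that workaround; the helper name `_loss_on_device` is hypothetical, the import path for RunningNLLMetric is assumed from the repo layout, and this is not the exact patch merged in #199:

```python
from laplace.utils.metrics import RunningNLLMetric  # import path assumed


def _loss_on_device(loss, device):
    """Mirror the change described above for BaseLaplace.optimize_prior_precision_base."""
    if loss is None:
        # Default metric, created and kept on the Laplace module's device
        # (self._device in BaseLaplace).
        return RunningNLLMetric().to(device)
    # User-supplied metric: move it to the same device instead of leaving it on the CPU.
    return loss.to(device)
```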

wiseodd commented Jun 11, 2024

Thanks! Yes, that try statement is indeed insidious. Maybe letting it fail (i.e. removing the try-except) is a better design. @runame, thoughts?

runame commented Jun 11, 2024

Yeah, removing the try-except statement sounds like the way to go, or alternatively restricting which errors are caught so that no RuntimeError is swallowed.
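
For illustration only (the function below is a hypothetical stand-in for the library's actual call site), the narrower handling would look roughly like this, so that a RuntimeError from a CPU/GPU device mismatch surfaces instead of silently leaving the metric at inf:

```python
def update_metric(metric, pred, target):
    """Sketch: restrict the except clause instead of removing it entirely."""
    try:
        metric.update(pred, target)
    except ValueError:
        # Only an explicitly expected, recoverable condition is ignored here;
        # a RuntimeError (e.g. tensors on different devices) now propagates
        # instead of being silently swallowed.
        pass
```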

wiseodd linked a pull request Jun 14, 2024 that will close this issue (#199)
wiseodd added this to the 0.2 milestone Jun 15, 2024