
[Bug] Infinite loop in HMC._find_reasonable_step_size #3156

Closed
@saitcakmak


Issue Description

In some edge cases, the while loop in HMC._find_reasonable_step_size never exits. The direction never changes, so step_size keeps shrinking until it becomes exactly 0, and the loop keeps running. The printout of t and step_size below shows an example.

[Screenshot (2022-11-18): printout of t and step_size]
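The step_size ends at exactly 0 because of ordinary double-precision arithmetic. The following is a stdlib-only sketch that assumes the loop halves step_size on every iteration while shrinking. That factor of 1/2 is an assumption and is not taken from this report. Under it, halving 1.0 reaches 0.0 after 1075 steps, and 0.0 stays 0.0 from then on. If the direction never flips, nothing else can end the loop.

```python
def halvings_until_zero(step_size: float = 1.0) -> int:
    """Count how many halvings turn step_size into exactly 0.0."""
    n = 0
    while step_size != 0.0:
        step_size *= 0.5  # assumed shrink factor while the direction is "decrease"
        n += 1
    return n


print(halvings_until_zero(1.0))  # 1075
# 2**-1074 (about 5e-324) is the smallest positive double; halving it rounds to 0.0.
print(5e-324 * 0.5)  # 0.0
# From then on the value is stuck: 0.0 * 0.5 is still 0.0.
print(0.0 * 0.5 == 0.0)  # True
```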

Environment

  • OS and Python version: CentOS 8, Python 3.8
  • PyTorch version: latest main
  • Pyro version: 1.5.2 (I checked that the bug is not fixed in the latest version)

Code Snippet

Unfortunately, I do not have a clean code snippet that reproduces the issue. Instead, I extracted the HMC object and the input z that trigger the infinite loop and saved them with torch.save. With the attached file, run:

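import torch

# The file holds a pickled HMC object, not just tensors, so it must be fully
# unpickled (weights_only=False; the default became True in PyTorch 2.6).
save_dict = torch.load("/home/saitcakmak/pyro_debug.pt", weights_only=False)  # update the file path
save_dict["hmc"]._find_reasonable_step_size(save_dict["z"])  # never returns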

Attachment: pyro_debug.pt.zip

Quick fix

As a hotfix, I replaced the line

while direction_new == direction:

with

while direction_new == direction and step_size > 1e-100:

which prevents the infinite loop.
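Here is a minimal standalone sketch of a doubling/halving step-size search that uses the same lower-bound guard as the hotfix. This is not Pyro's implementation. The names `find_reasonable_step_size` and `log_accept_ratio` are hypothetical. The callback stands in for -delta_energy of one trial trajectory. The threshold log(0.8) and the factor of 2 are also assumptions.

```python
import math
from typing import Callable


def find_reasonable_step_size(
    log_accept_ratio: Callable[[float], float],  # hypothetical: stands in for -delta_energy
    step_size: float = 1.0,
    threshold: float = math.log(0.8),  # assumed acceptance threshold
    min_step_size: float = 1e-100,  # same lower bound as the hotfix
) -> float:
    """Double or halve step_size until the accept test flips direction."""

    def direction_of(eps: float) -> int:
        # NaN compares False, so a diverging trial counts as "decrease" (-1).
        return 1 if threshold < log_accept_ratio(eps) else -1

    direction = direction_of(step_size)
    scale = 2.0 ** direction  # 2 when increasing, 1/2 when decreasing
    direction_new = direction
    # Without the step_size check, a direction stuck at -1 drives step_size
    # to 0.0 and the loop never exits.
    while direction_new == direction and step_size > min_step_size:
        step_size *= scale
        direction_new = direction_of(step_size)
    return step_size


print(find_reasonable_step_size(lambda eps: -eps))  # 0.125
# Every trial diverges (NaN), so the direction never flips; the guard stops it.
print(find_reasonable_step_size(lambda eps: float("nan")) == 2.0 ** -333)  # True
```

When the callback always returns NaN, the guard stops the search at 2**-333, which is the first power of two at or below 1e-100. Without the guard, the search would spin forever. Like the hotfix, this sketch only bounds the shrinking direction. A callback that always asks for a larger step would still need an upper bound.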

cc @dme65, @Balandat
