Make ReparamMessenger aware of InitMessenger, simplifying initialization #2876

Merged · 24 commits · Jun 29, 2021

Conversation

@fritzo (Member) commented Jun 17, 2021:

Addresses #2878.
Pair-coded with @eb8680 and @fehiepsi.

This aims to make initialization compatible with model reparametrization, for example:

import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer.autoguide import AutoNormal, init_to_value
from pyro.infer.reparam import LocScaleReparam

def model():
    loc = pyro.sample("loc", dist.Normal(0, 1))
    scale = pyro.sample("scale", dist.LogNormal(0, 1))
    pyro.sample("x", dist.Normal(loc, scale))

reparam_model = poutine.reparam(model, {"x": LocScaleReparam()})
reparam_guide = AutoNormal(
    reparam_model, init_loc_fn=init_to_value(values={"x": torch.tensor(0.0)})
)
assert reparam_guide()["x"] == 0

This involves a major interface change for reparametrizers.

The problem is that the init_loc_fn passed to autoguides and HMC is often customized to particular latent sites, but reparametrizers replace those sites with new auxiliary latent sites. The solution proposed in this PR is to require each reparametrizer to know how to transform original values to latent values, and to use that transform during initialization.
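As a rough illustration of the new requirement (a minimal sketch with hypothetical names, not the actual Reparam interface), a decentering reparametrizer would pair its forward rewrite with an inverse value transform:

class DecenteringSketch:
    """Hypothetical, simplified reparametrizer: besides rewriting the site,
    it knows how to map values between the original site "x" and its
    auxiliary site "x_decentered"."""

    def to_original(self, loc, scale, x_decentered):
        # auxiliary -> original: x = loc + scale * x_decentered
        return loc + scale * x_decentered

    def to_auxiliary(self, loc, scale, x):
        # original -> auxiliary: this is what lets init_to_value({"x": ...})
        # initialize the auxiliary site "x_decentered"
        return (x - loc) / scale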

Tasks deferred to future PRs

  • Support value initialization in NeuTraReparam
  • Support value initialization in StructuredReparam

Tested

  • tests for all new reparameterizers
  • tests for nesting and msg["is_observed"]
  • refactoring of pyro.contrib.epidemiology is covered by existing unit tests
  • check_init_reparam() tests for all reparametrizers (ensuring coverage)

@fehiepsi (Member) commented:

I'd like to have this feature implemented; it has been requested several times before.

@fritzo fritzo marked this pull request as ready for review June 23, 2021 17:26
@fritzo fritzo requested a review from fehiepsi June 23, 2021 17:53
     logger.info("Heuristic init: {}".format(", ".join(
         "{}={:0.3g}".format(k, v.item())
         for k, v in sorted(init_values.items())
         if v.numel() == 1)))
-    return init_to_value(values=init_values)
+    return init_to_value(values=init_values, fallback=None)
@fritzo (Member, Author) commented:

I've added fallback logic to a few init strategies (hoping to port these to NumPyro and combine with pyro-ppl/numpyro#1058), and set the fallback to None here, which would error if the new initialization logic were to fail.

     site=None,
     num_samples=15,
     *,
     fallback: Optional[Callable] = init_to_feasible,
@fritzo (Member, Author) commented:

I've added fallback kwargs to a few init strategies mainly to make it easier for users to catch unintended initialization.
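For example, a usage sketch based on the signature shown above:

from pyro.infer.autoguide import init_to_median

# With the default fallback (init_to_feasible), sites where the median
# estimate fails are silently initialized to a feasible point instead.
# Passing fallback=None turns such failures into errors, surfacing
# unintended initialization early.
strict_init = init_to_median(num_samples=15, fallback=None)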

@@ -194,4 +234,8 @@ def _pyro_sample(self, msg):
                 "{} provided invalid shape for site {}:\nexpected {}\nactual {}"
                 .format(self.init_fn, msg["name"], msg["value"].shape, value.shape))
         msg["value"] = value
-        msg["done"] = True
@fritzo (Member, Author) commented:

@eb8680 can you confirm this is ok? I think we want to avoid setting done here so as to enable subsequent reparametrization.
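For context, a simplified sketch of the convention (hypothetical helper, not Pyro's actual handler stack):

def process_site(handlers, msg):
    # Handlers run in turn; once one marks the message done, the rest skip it.
    for handler in handlers:
        if msg["done"]:
            break  # a previous handler finalized this site
        handler.process(msg)
    # InitMessenger sets msg["value"] but leaves msg["done"] False, so a
    # ReparamMessenger later in the stack can still rewrite the site.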

@fritzo fritzo requested a review from eb8680 June 23, 2021 17:57
@fritzo (Member, Author) commented Jun 23, 2021:

@eb8680 @fehiepsi let me know if you'd like to do a code walk-through. This is a big refactoring and I'd like to eventually port it to NumPyro.

@fritzo fritzo removed the refactor label Jun 23, 2021
@fritzo (Member, Author) commented Jun 25, 2021:

Looks like all tests are passing again; let me know if you have any questions.

@fehiepsi (Member) left a comment:

The design looks great to me. Thanks for taking care of the user-defined strategies and backward compatibility. I just caught one small nit.

Review thread on pyro/infer/autoguide/initialization.py (outdated, resolved)
             self.guide.get_base_dist().mask(False))
         with ExitStack() as stack:
             for plate in self.guide.plates.values():
                 stack.enter_context(block_plate(dim=plate.dim, strict=False))
@fehiepsi (Member) commented:

I just wonder why we need to block_plate here (and some other places). Could you elaborate a bit?

@fritzo (Member, Author) replied:

The problem addressed here arises in models like

@neutra.reparam()
def model():
    with pyro.plate("plate", 10):
        pyro.sample("x", dist.Normal(0, 1))

where the shared latent z_unconstrained is a single latent variable that is not independent along batch dimensions; rather it should have empty batch_shape (except for particle plates that are outside the guide and hence not blocked) and event_shape = (fn.shape().numel(),). Without this block statement, z_unconstrained would have event_shape = (10,) but would then be broadcast up by the enclosing plate to batch_shape + event_shape = (10, 10). Thus this block_plate merely hoists the shared latent variable outside of all plates.

I believe this is what we want to do with NeuTra, but IIUC it prevents "batched neutra", in that we won't be able to share a single normalizing flow along a batch dimension and subsample. I'm not sure that prevented use case is realistic though, since our HMC and NUTS do not support data subsampling.
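A minimal sketch of the shape issue (hypothetical guide snippet; assumes block_plate is importable from pyro.poutine.plate_messenger, as used in the diff above):

import torch
import pyro
import pyro.distributions as dist
from pyro.poutine.plate_messenger import block_plate

def guide_sketch():
    # One shared latent across all 10 points: batch_shape (), event_shape (10,).
    base = dist.Normal(torch.zeros(10), 1.0).to_event(1)
    with pyro.plate("plate", 10):
        # Without block_plate, the enclosing plate would broadcast this site
        # to batch_shape (10,), i.e. overall shape (10, 10).
        with block_plate(dim=-1, strict=False):
            z = pyro.sample("z_unconstrained", base)
    return z  # shape (10,)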

         if not torch._C._get_tracing_state():
             assert new_msg["value"].shape == msg["value"].shape

         # Warn if a custom init method is overwritten by another init method.
@fehiepsi (Member) commented:

Sounds reasonable to me.

@fritzo (Member, Author) replied:

Yeah, this is from personal experience: I once spent about 6 hours hunting down a bug that turned out to be an overwritten initialization 😅

@fritzo fritzo added this to the 1.7 release milestone Jun 28, 2021
@eb8680 (Member) left a comment:

LGTM after our walkthrough, with the extra tests and nits addressed.
