Make ReparamMessenger aware of InitMessenger, simplifying initialization #2876
Conversation
I'd like to have this feature implemented. This has been requested several times previously.
  logger.info("Heuristic init: {}".format(", ".join(
      "{}={:0.3g}".format(k, v.item())
      for k, v in sorted(init_values.items())
      if v.numel() == 1)))
- return init_to_value(values=init_values)
+ return init_to_value(values=init_values, fallback=None)
I've added fallback logic to a few init strategies (hoping to port these to NumPyro and combine with pyro-ppl/numpyro#1058), and set the fallback to None here, which would error if the new initialization logic were to fail.
    site=None,
    num_samples=15,
    *,
    fallback: Optional[Callable] = init_to_feasible,
I've added fallback kwargs to a few init strategies, mainly to make it easier for users to catch unintended initialization.
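A minimal usage sketch of the fallback kwarg (my own illustration, assuming the init_to_median signature shown in the diff above, not code from this PR): passing fallback=None makes the strategy raise instead of silently falling back to init_to_feasible when it cannot compute an initial value.

```python
import pyro
import pyro.distributions as dist
from pyro.infer.autoguide import AutoNormal, init_to_median

def model():
    pyro.sample("x", dist.Normal(0.0, 1.0))

# With fallback=None, any site whose median cannot be computed raises an
# error rather than being silently initialized by init_to_feasible.
guide = AutoNormal(model, init_loc_fn=init_to_median(num_samples=15, fallback=None))
guide()  # triggers lazy initialization of the guide's parameters
```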
@@ -194,4 +234,8 @@ def _pyro_sample(self, msg):
        "{} provided invalid shape for site {}:\nexpected {}\nactual {}"
        .format(self.init_fn, msg["name"], msg["value"].shape, value.shape))
    msg["value"] = value
    msg["done"] = True
@eb8680 can you confirm this is ok? I think we want to avoid setting done here so as to enable subsequent reparametrization.
Looks like all tests are passing again; let me know if you have any questions.
The design looks great to me. Thanks for taking care of the user-defined strategies and backward compatibility. I just caught a small nit.
        self.guide.get_base_dist().mask(False))
    with ExitStack() as stack:
        for plate in self.guide.plates.values():
            stack.enter_context(block_plate(dim=plate.dim, strict=False))
I just wonder why we need to block_plate here (and some other places). Could you elaborate a bit?
The problem addressed arises in models like
@neutra.reparam()
def model():
    with pyro.plate("plate", 10):
        pyro.sample("x", dist.Normal(0, 1))
where the z_unconstrained shared latent is a single latent variable that is not independent along batch dimensions; rather it should have empty batch_shape (except for particle plates that are outside the guide and hence not blocked) and event_shape = (fn.shape().numel(),). Without this block statement the z_unconstrained would have event_shape = (10,) but would then be broadcast up by an enclosing plate to batch_shape + event_shape = (10, 10). Thus this block_plate merely hoists that shared latent variable outside of all plates.
I believe this is what we want to do with NeuTra, but IIUC it prevents "batched neutra" in that we won't be able to share a single normalizing flow along a batch dimension and subsample. I'm not sure whether that prevented use case is realistic though, since our HMC and NUTS do not support data subsampling.
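A standalone sketch (my own, not from the PR) of the broadcasting behavior described above: a latent with event_shape (10,) declared inside a size-10 plate picks up an extra batch dimension and becomes a (10, 10) value, which is why the shared latent needs to be hoisted outside all plates.

```python
import torch
import pyro
import pyro.distributions as dist

def model():
    with pyro.plate("plate", 10):
        # event_shape (10,), empty batch_shape until the plate broadcasts it
        pyro.sample("z_unconstrained",
                    dist.Normal(torch.zeros(10), 1.0).to_event(1))

trace = pyro.poutine.trace(model).get_trace()
print(trace.nodes["z_unconstrained"]["value"].shape)  # torch.Size([10, 10])
```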
if not torch._C._get_tracing_state():
    assert new_msg["value"].shape == msg["value"].shape

# Warn if a custom init method is overwritten by another init method. |
Sounds reasonable to me.
Yeah, this comes from personal experience: I once spent about 6 hours hunting down a bug that turned out to be an overwritten initialization 😅
LGTM after our walkthrough, with the extra tests and nits addressed.
Addresses #2878
Pair coded with @eb8680 and @fehiepsi.
This aims to make initialization compatible with model reparametrization, for example when an init strategy refers to a site that a reparametrizer rewrites.
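A minimal sketch of the intended usage (the specific reparametrizer, site name, and init value below are illustrative assumptions, not the PR's original example): the init strategy is stated in terms of the original site "x", even though LocScaleReparam replaces that site with an auxiliary decentered site.

```python
import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.autoguide import AutoNormal, init_to_value
from pyro.infer.reparam import LocScaleReparam

@poutine.reparam(config={"x": LocScaleReparam()})
def model():
    # Under the hood "x" is replaced by an auxiliary site such as
    # "x_decentered", but initialization is still specified for "x".
    pyro.sample("x", dist.Normal(2.0, 0.1))

guide = AutoNormal(model, init_loc_fn=init_to_value(values={"x": torch.tensor(2.0)}))
```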
This involves a major interface change for reparametrizers.
The problem is that the init_loc_fn passed to autoguides and HMC is often customized to particular latent sites, but reparametrizers replace those sites with new auxiliary latent sites. The solution proposed in this PR is to require each reparametrizer to know how to transform original values to latent values, and to use that transform during initialization.

Tasks deferred to future PRs
- NeuTraReparam
- StructuredReparam
Tested
msg["is_observed"]
pyro.contrib.epidemiology
is covered by existing unit testscheck_init_reparam()
tests for all reparametrizers (ensuring coverage)
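A rough, heavily assumed sketch of the kind of check such a test might perform (check_init_reparam's actual signature and body are not shown in this thread, so everything below is an illustration rather than the real helper): run a reparametrized model under InitMessenger and verify that both the auxiliary and the original site receive finite values of the expected shape.

```python
import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.autoguide.initialization import InitMessenger, init_to_median
from pyro.infer.reparam import LocScaleReparam

def model():
    pyro.sample("x", dist.Normal(1.0, 2.0))

reparam_model = poutine.reparam(model, config={"x": LocScaleReparam(centered=0.0)})

# Initialize the reparametrized model and inspect the resulting trace.
with InitMessenger(init_to_median(num_samples=10)):
    trace = poutine.trace(reparam_model).get_trace()

for name in ["x_decentered", "x"]:
    value = trace.nodes[name]["value"]
    assert value.shape == ()
    assert torch.isfinite(value).all()
```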