diff --git a/pyro/poutine/equalize_messenger.py b/pyro/poutine/equalize_messenger.py index 1bc79a5521..31e524a651 100644 --- a/pyro/poutine/equalize_messenger.py +++ b/pyro/poutine/equalize_messenger.py @@ -68,7 +68,7 @@ def __init__( self, sites: Union[str, List[str]], type: Optional[str] = "sample", - keep_dist: bool = False, + keep_dist: Optional[bool] = False, ) -> None: super().__init__() self.sites = [sites] if isinstance(sites, str) else sites diff --git a/pyro/poutine/handlers.py b/pyro/poutine/handlers.py index 343b1a1f4b..bc1ba91de8 100644 --- a/pyro/poutine/handlers.py +++ b/pyro/poutine/handlers.py @@ -306,7 +306,8 @@ def escape( # type: ignore[empty-body] def equalize( sites: Union[str, List[str]], type: Optional[str], -) -> ConditionMessenger: ... + keep_dist: Optional[bool], +) -> EqualizeMessenger: ... @overload @@ -314,6 +315,7 @@ def equalize( fn: Callable[_P, _T], sites: Union[str, List[str]], type: Optional[str], + keep_dist: Optional[bool], ) -> Callable[_P, _T]: ... @@ -322,6 +324,7 @@ def equalize( # type: ignore[empty-body] fn: Callable[_P, _T], sites: Union[str, List[str]], type: Optional[str], + keep_dist: Optional[bool], ) -> Union[EqualizeMessenger, Callable[_P, _T]]: ...