Skip to content

[BUG] Numerical Instability issues with torchrl.modules.TanhNormal  #2199

Open
@N00bcak

Description

@N00bcak

Describe the bug

When training on PettingZoo/MultiWalker-v9 with Multi-Agent Soft Actor-Critic, all losses (loss_actor, loss_qvalue, loss_alpha) explode within at most ~1M environment steps.

This phenomenon occurs regardless of (reasonable) hyperparameter and gradient clipping threshold choice.

To Reproduce

from copy import deepcopy
import tqdm
import numpy as np
from gymnasium.spaces import Box

import logging
import math

import torch
from torch import nn
import torch.distributions as D

from torchrl.data.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import RandomSampler
from torchrl.data.replay_buffers.storages import LazyMemmapStorage 

from torchrl.envs import (
    check_env_specs,
    PettingZooEnv, 
    ParallelEnv,
    GymEnv
)

from torchrl.modules import AdditiveGaussianWrapper, ProbabilisticActor
from torchrl.modules.models import MLP
from torchrl.modules.models.multiagent import (
    MultiAgentMLP,
    MultiAgentNetBase
)
from torchrl.collectors import SyncDataCollector, MultiSyncDataCollector, RandomPolicy

from torchrl.objectives import SACLoss, SoftUpdate

from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential, NormalParamExtractor
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.envs import EnvCreator, TransformedEnv, Compose, Transform, RewardSum, ObservationNorm, StepCounter
from torchrl.record import CSVLogger, VideoRecorder, PixelRenderTransform

import multiprocessing as mp

# Shared small constant: passed as `eps` to the Adam optimizers in the training
# script and used as the clamp margin / log-prob floor argument in the
# tanh-distribution workaround classes further down.
EPS = 1e-7
class SMACCNet(MultiAgentNetBase):
    """Fixed-size MLP for Multi-Agent SAC.

    A stripped-down counterpart of ``MultiAgentMLP``
    (https://pytorch.org/rl/main/_modules/torchrl/modules/models/multiagent.html).
    Every network built by this class has the layout
    ``Linear(in, 400) -> act -> Linear(400, 300) -> act -> Linear(300, out)``.

    Args:
        n_agent_inputs: per-agent input feature size.
        n_agent_outputs: per-agent output feature size.
        n_agents: expected size of the agent dimension (``shape[-2]``).
        centralised: if ``True``, inputs are flattened over the agent and
            feature dimensions before reaching the network.
        share_params: forwarded to ``MultiAgentNetBase``.
        device: device the built networks are moved to.
        activation_class: zero-argument callable producing the activation
            placed after each hidden layer.
        **kwargs: forwarded to ``MultiAgentNetBase.__init__``.
    """

    def __init__(
        self,
        n_agent_inputs: int | None,
        n_agent_outputs: int,
        n_agents: int,
        centralised: bool,
        share_params: bool,
        device = 'cpu',
        activation_class = nn.Tanh,
        **kwargs,
    ):
        # These attributes are stored *before* the base constructor runs;
        # presumably the base class calls `_build_single_net` during its own
        # __init__ -- TODO confirm against MultiAgentNetBase.
        self.device = device
        self.activation_class = activation_class
        self.centralised = centralised
        self.share_params = share_params
        self.n_agent_outputs = n_agent_outputs
        self.n_agent_inputs = n_agent_inputs
        self.n_agents = n_agents

        super().__init__(
            n_agents=n_agents,
            centralised=centralised,
            share_params=share_params,
            agent_dim=-2,
            device=device,
            **kwargs,
        )

    # Copied over from MultiAgentMLP.
    def _pre_forward_check(self, inputs):
        """Validate the agent dimension and flatten it away when centralised.

        Raises:
            ValueError: if ``inputs.shape[-2]`` differs from ``self.n_agents``.
        """
        agent_dim_matches = inputs.shape[-2] == self.n_agents
        if not agent_dim_matches:
            raise ValueError(
                f"Multi-agent network expected input with shape[-2]={self.n_agents},"
                f" but got {inputs.shape}"
            )
        # A centralised model sees every agent's features at once.
        return inputs.flatten(-2, -1) if self.centralised else inputs

    def init_net_params(self, net):
        """Xavier-initialise every ``nn.Linear`` in ``net`` and zero its bias.

        Linear layers whose ``out_features`` equals ``self.n_agent_outputs``
        get a gain of 1/100; all other linear layers use a gain of 1.
        Returns ``net`` itself.
        """
        def _init_linear(module):
            if not isinstance(module, nn.Linear):
                return
            is_output_sized = module.out_features == self.n_agent_outputs
            gain = 1. / 100 if is_output_sized else 1. / 1
            torch.nn.init.xavier_uniform_(module.weight, gain = gain)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

        net.apply(_init_linear)
        return net

    def _build_single_net(self, *, device, **kwargs):
        """Build and initialise one 400-300 MLP.

        The ``device`` argument is ignored: the network is moved to
        ``self.device`` instead (the original author's workaround for
        ``MultiSyncDataCollector``).
        """
        in_features = self.n_agent_inputs
        if in_features is not None and self.centralised:
            in_features = in_features * self.n_agents

        layers = [
            nn.Linear(in_features, 400),
            self.activation_class(),
            nn.Linear(400, 300),
            self.activation_class(),
            nn.Linear(300, self.n_agent_outputs),
        ]
        # Bandaid fix to use MultiSyncDataCollector
        net = nn.Sequential(*layers).to(self.device)
        return self.init_net_params(net)

class TqdmLoggingHandler(logging.Handler):
    """Logging handler that prints each record through ``tqdm.tqdm.write``.

    Presumably used so log lines do not garble an active tqdm progress bar.
    """

    def __init__(self, level=logging.NOTSET):
        """Create the handler with the given minimum ``level``."""
        super().__init__(level)

    def emit(self, record):
        """Format ``record``, write it via tqdm, then flush.

        Any exception raised while doing so is passed to ``handleError``.
        """
        try:
            tqdm.tqdm.write(self.format(record))
            self.flush()
        except Exception:
            self.handleError(record)

# Main Function
# Script entry point: logging, multiprocessing start method, hyperparameters
# and RNG seeding. Everything under this guard runs only on direct execution.
if __name__ == "__main__":
    # Send this module's log records through the tqdm-aware handler only
    # (propagation to the root logger is switched off).
    logging.basicConfig(level = logging.INFO)
    logger = logging.getLogger(__name__)
    logger.propagate = False
    logger.addHandler(TqdmLoggingHandler())

    # Force the "spawn" start method for multiprocessing.
    mp.set_start_method("spawn", force = True)
    
    NUM_AGENTS = 3  # number of walkers; also the networks' `n_agents`
    NUM_CRITICS = 2  # NOTE(review): never referenced below; SACLoss gets `num_qvalue_nets = 2` directly
    NUM_EXPLORE_WORKERS = 8  # number of env factories handed to the collector
    EXPLORATION_STEPS = 30000  # divided by BATCH_SIZE to get collector batches between update rounds
    MAX_EPISODE_STEPS = 1000  # env `max_cycles`, StepCounter limit and eval rollout length
    DEVICE = "cuda"  # device for the loss module and sampled batches
    REPLAY_BUFFER_SIZE = int(1e6)
    VALUE_GAMMA = 0.99  # gamma passed to the value estimator
    MAX_GRAD_NORM = 1.0  # gradient-norm clipping threshold for all three optimizers
    BATCH_SIZE = 256  # used for both collector `frames_per_batch` and replay sampling
    LR = 1e-4
    UPDATE_STEPS_PER_EXPLORATION = 1500  # gradient updates per update round
    WARMUP_STEPS = 0 #int(2e5)
    TRAIN_TIMESTEPS = int(1e7)  # collector `total_frames` and progress-bar total
    EVAL_INTERVAL = 1 #int(9e4 // EXPLORATION_STEPS) # Every 500k steps or so, evaluate once.

    # Seed torch and numpy (the env itself is created with a hard-coded seed of 42).
    SEED = 42
    torch.manual_seed(SEED)
    np.random.seed(SEED)

    # https://pytorch.org/rl/stable/tutorials/multiagent_competitive_ddpg.html
    # More tutorials: https://pytorch.org/tutorials/advanced/pendulum.html
    # Toy test: https://pettingzoo.farama.org/environments/sisl/multiwalker/
    def env_fn(mode, parallel = True):
        """Return a zero-argument factory that builds the transformed multiwalker env.

        Args:
            mode: forwarded as ``render_mode`` to ``PettingZooEnv``.
            parallel: if ``True``, wrap the base env in a 4-worker
                ``ParallelEnv``; otherwise use a single ``PettingZooEnv``.

        Returns:
            A callable that creates a ``TransformedEnv`` with ``StepCounter``
            and ``RewardSum`` applied.
        """
        def base_env_fn():
            # NOTE: the `parallel = True` below is PettingZooEnv's own argument
            # and is independent of this function's `parallel` parameter.
            return PettingZooEnv(task = "multiwalker_v9", 
                                    parallel = True,
                                    seed = 42,
                                    n_walkers = NUM_AGENTS, 
                                    terminate_reward = -5.0,
                                    forward_reward = 1.0,
                                    fall_reward = -1.0,
                                    shared_reward = False, 
                                    max_cycles = MAX_EPISODE_STEPS, 
                                    render_mode = mode, 
                                    device = "cpu"
                                )

        if parallel:
            # Don't use.
            # https://discuss.pytorch.org/t/pettingzoo-trouble-running-multiple-marl-environments-in-parallel/203706/

            env = lambda: ParallelEnv(num_workers = 4,  # noqa: E731
                                        create_env_fn = base_env_fn, 
                                        device = "cpu",
                                        mp_start_method = "spawn",
                                        serial_for_single = True
                                    )
        else:
            env = base_env_fn # noqa: E731

        def env_with_transforms():
            # Observation normalisation was tried and left disabled:
            # dummy_env = base_env_fn()
            # dummy_obs_transform = ObservationNorm(in_keys = [("walker", "observation")], standard_normal = True)
            # dummy_env = TransformedEnv(dummy_env, dummy_obs_transform)
            # dummy_obs_transform.init_stats(10000)


            init_env = env()
            # obs_transform = ObservationNorm(loc = dummy_obs_transform.loc + EPS,
            #                                 scale = dummy_obs_transform.scale + EPS,
            #                                 in_keys = [("walker", "observation")], 
            #                                 standard_normal = True
            #                             )
            # Cap episode length and accumulate per-episode reward under
            # ("walker", "episode_reward").
            init_env = TransformedEnv(init_env, Compose(
                                            StepCounter(max_steps = MAX_EPISODE_STEPS),
                                            RewardSum(
                                                in_keys = [init_env.reward_key for _ in range(NUM_AGENTS)], 
                                                out_keys = [("walker", "episode_reward")] * NUM_AGENTS, 
                                                reset_keys = ["_reset"] * NUM_AGENTS
                                            ),
                                            # obs_transform
                                        )
                                    )
            # del dummy_env, dummy_obs_transform
            return init_env

        return env_with_transforms

    # Training-side env: in this script it is only used for spec lookups and
    # sanity checks; data collection uses the collector's own env copies.
    train_env = env_fn(None, parallel = False)()

    if train_env.is_closed:
        train_env.start()
    # Evaluation env renders "rgb_array" frames so rollouts can be recorded.
    eval_env = env_fn("rgb_array", parallel = False)()
    video_recorder = VideoRecorder(
                                    CSVLogger("multiwalker-toy-test", video_format = "mp4"), 
                                    tag = "rendered", 
                                    in_keys = ["pixels_record"]
                                )
    
    # Call the parent's render function
    # (frames are written to "pixels_record", which the recorder reads).
    eval_env.append_transform(PixelRenderTransform(out_keys = ["pixels_record"]))
    eval_env.append_transform(video_recorder)

    if eval_env.is_closed:
        eval_env.start()

    # Validate both environments' specs before building anything on top of them.
    check_env_specs(train_env)
    check_env_specs(eval_env)

    # Dump specs and keys for debugging.
    print(f"Action: {train_env.full_action_spec}, Reward: {train_env.full_reward_spec}, Done: {train_env.full_done_spec}, Observation: {train_env.full_observation_spec}")
    
    print(f"group_map: {train_env.group_map}")
    print(f"Action: {train_env.action_keys}, Reward: {train_env.reward_keys}, Done: {train_env.done_keys}")
    
    # NOTE: The input and output spaces to be fed in are on a PER-AGENT basis.
    # Basically, if you have 16 agents observing 3D velocity and outputting speed (the magnitude), 
    # n_agent_inputs = 3, n_agent_outputs = 1.  
    obs_dim = train_env.full_observation_spec["walker", "observation"].shape[-1]
    action_dim = train_env.full_action_spec["walker", "action"].shape[-1]

    # Decentralised policy with shared parameters. It emits 2 * action_dim
    # values per agent, which NormalParamExtractor turns into the two outputs
    # stored as ("walker", "loc") and ("walker", "scale").
    policy_net = nn.Sequential(
                        SMACCNet(n_agent_inputs = obs_dim,
                          n_agent_outputs = 2 * action_dim, 
                          n_agents = NUM_AGENTS,
                          centralised = False, 
                          share_params = True, 
                          device = "cpu",
                          activation_class = nn.LeakyReLU, 
                        ),
                        NormalParamExtractor(),
                    )
    
    # Centralised critic: consumes the concatenated observation and action,
    # and outputs one value per agent.
    critic_net = SMACCNet(n_agent_inputs = obs_dim + action_dim,
                          n_agent_outputs = 1,
                          n_agents = NUM_AGENTS,
                          centralised = True,
                          share_params = True, 
                          device = "cpu",
                          activation_class = nn.LeakyReLU, 
                        )

    # Hook our networks to TensorDictModules so they can be a part of the TensorDict pipeline...
    policy_net_td_module = TensorDictModule(module = policy_net,
                                            in_keys = [("walker", "observation")],
                                            # NOTE: These outputs must match with the parameter names of the 
                                            # distribution you are using!
                                            out_keys = [("walker", "loc"), ("walker", "scale")]
                                        )
 
    # Concatenate observation and action along the last dimension for the critic.
    obs_act_module = TensorDictModule(lambda obs, act: torch.cat([obs, act], dim = -1),
                                        in_keys = [("walker", "observation"), ("walker", "action")],
                                        out_keys = [("walker", "obs_act")]
                                    )
    critic_net_td_module = TensorDictModule(module = critic_net,
                                            in_keys = [("walker", "obs_act")],
                                            out_keys = [("walker", "state_action_value")]
                                        )

    # `TanhNormal` is used below but is missing from the file-level
    # `torchrl.modules` import, so bring it into scope here.
    from torchrl.modules import TanhNormal

    # Attach our raw policy network to a probabilistic actor
    policy_actor = ProbabilisticActor(
        module = policy_net_td_module,
        spec = train_env.full_action_spec["walker", "action"],
        in_keys = [("walker", "loc"), ("walker", "scale")],
        out_keys = [("walker", "action")],
        # TanhNormal is based off of pytorch, which as far as we know, 
        # implements a numerically stable log det jacobian.
        distribution_class = TanhNormal,
        # Bound the squashed distribution to the action spec's limits.
        distribution_kwargs = {
            "min": train_env.full_action_spec["walker", "action"].space.low,
            "max": train_env.full_action_spec["walker", "action"].space.high,
        },
        return_log_prob = True,
    )
    
    # Dry-run the actor once on a fake tensordict (no gradients needed).
    with torch.no_grad():
        fake_td = train_env.fake_tensordict()
        policy_actor(fake_td)

    # Exploration policy used by the collector: the actor plus additive
    # Gaussian noise whose sigma anneals from 0.3 to 0.1 over half of training.
    dora = AdditiveGaussianWrapper(
        policy = policy_actor,
        action_key = ("walker", "action"),
        sigma_init = 0.3,
        sigma_end = 0.1,
        annealing_num_steps = TRAIN_TIMESTEPS // 2
    )

    # Critic pipeline: build ("walker", "obs_act"), then evaluate the critic on it.
    critic_actor = TensorDictSequential(
                            obs_act_module, critic_net_td_module
                        ) 

    # Data collection: NUM_EXPLORE_WORKERS env factories driven by the
    # exploration policy, yielding BATCH_SIZE frames per iteration on CPU.
    collector = MultiSyncDataCollector(
                    [env_fn(None, parallel = False) for _ in range(NUM_EXPLORE_WORKERS)], 
                    policy = dora,
                    frames_per_batch = BATCH_SIZE,
                    max_frames_per_traj = 0,
                    total_frames = TRAIN_TIMESTEPS,
                    device = "cpu",
                    reset_at_each_iter = False
                )

    replay_buffer = TensorDictReplayBuffer(
        storage = LazyMemmapStorage(
            REPLAY_BUFFER_SIZE, device = "cpu",
        ),  # We will store up to memory_size multi-agent transitions
        sampler = RandomSampler(),
        batch_size = BATCH_SIZE,  # We will sample batches of this size
    )

    # SAC objective; actor and critic are moved to DEVICE here.
    sac_loss = SACLoss(policy_actor.to(DEVICE), 
                        qvalue_network = critic_actor.to(DEVICE), 
                        num_qvalue_nets = 2,
                        loss_function = "l2",
                        delay_qvalue = True,
                        alpha_init = 0.1
                        )
    # Point the loss at the nested ("walker", ...) keys used by this env.
    sac_loss.set_keys(
        action = ("walker", "action"),
        state_action_value = ("walker", "state_action_value"),
        reward = ("walker", "reward"),
        done = ("walker", "done"),
        terminated = ("walker", "terminated"),
    )

    sac_loss.make_value_estimator(gamma = VALUE_GAMMA)

    # Soft (Polyak) update helper for the loss module; stepped once per gradient update.
    polyak_updater = SoftUpdate(sac_loss, tau = 0.005) 

    # Flat parameter lists, reused by the optimizers and by gradient clipping.
    critic_params = list(sac_loss.qvalue_network_params.flatten_keys().values())
    actor_params = list(sac_loss.actor_network_params.flatten_keys().values())

    # Separate Adam optimizers for actor, critic and the entropy coefficient;
    # only the first two use weight decay.
    optimizer_actor = torch.optim.Adam(
        actor_params,
        lr = LR,
        weight_decay = 5e-4,
        eps = EPS,
        betas = (0.9, 0.98)
    )
    optimizer_critic = torch.optim.Adam(
        critic_params,
        lr = LR,
        weight_decay = 5e-4,
        eps = EPS,
        betas = (0.9, 0.98)
    )
    optimizer_alpha = torch.optim.Adam(
        [sac_loss.log_alpha],
        lr = LR,
        eps = EPS,
        betas = (0.9, 0.98)
    )

    # Training loop: collect a batch, store it, and every EXPLORATION_BATCHES
    # collector iterations run UPDATE_STEPS_PER_EXPLORATION SAC updates.
    # Evaluation and video dump happen every EVAL_INTERVAL * EXPLORATION_BATCHES
    # iterations.
    pbar = tqdm.tqdm(total = TRAIN_TIMESTEPS)
    total_frames = 0
    backprop_ctr = 0  # collector iterations since the last update round
    train_rews, ep_lengths = [], []
    EXPLORATION_BATCHES = EXPLORATION_STEPS // BATCH_SIZE
    for i, tensordict in enumerate(collector):

        # Push the latest policy weights to the collector workers.
        collector.update_policy_weights_()

        pbar.update(tensordict.numel())

        tensordict = tensordict.reshape(-1)
        current_frames = tensordict.numel()
        # Add to replay buffer
        replay_buffer.extend(tensordict.cpu())
        total_frames += current_frames

        backprop_ctr += 1
        # Optimization steps
        if total_frames >= WARMUP_STEPS and backprop_ctr > EXPLORATION_BATCHES:
            backprop_ctr = 0
            # Per-update logs, indexed by update step.
            losses = TensorDict({}, batch_size = [UPDATE_STEPS_PER_EXPLORATION])
            alphas = TensorDict({}, batch_size = [UPDATE_STEPS_PER_EXPLORATION])
            for j in range(UPDATE_STEPS_PER_EXPLORATION):
                # Sample from replay buffer
                sampled_tensordict = replay_buffer.sample()

                if str(sampled_tensordict.device) != DEVICE:
                    sampled_tensordict = sampled_tensordict.to(DEVICE, non_blocking = False)
                else:
                    sampled_tensordict = sampled_tensordict.clone()

                try:
                    # Compute loss
                    loss_td = sac_loss(sampled_tensordict)
                except KeyError as err:
                    # Re-raise with the offending batch attached, keeping the
                    # original KeyError as the explicit cause.
                    raise Exception(f"wtf {sampled_tensordict}\n{obs_act_module(sampled_tensordict)['walker', 'obs_act']}") from err


                actor_loss = loss_td["loss_actor"]
                q_loss = loss_td["loss_qvalue"]
                alpha_loss = loss_td["loss_alpha"]

                # Update actor
                optimizer_actor.zero_grad()
                actor_loss.backward()
                actor_grad_norm = torch.nn.utils.clip_grad_norm_(actor_params, max_norm = MAX_GRAD_NORM)
                optimizer_actor.step()

                # Update critic
                optimizer_critic.zero_grad()
                q_loss.backward()
                q_grad_norm = torch.nn.utils.clip_grad_norm_(critic_params, max_norm = MAX_GRAD_NORM)
                optimizer_critic.step()

                # Update alpha
                optimizer_alpha.zero_grad()
                alpha_loss.backward()
                alpha_grad_norm = torch.nn.utils.clip_grad_norm_([sac_loss.log_alpha], max_norm = MAX_GRAD_NORM)
                optimizer_alpha.step()

                losses[j] = loss_td.select(
                    "loss_actor", "loss_qvalue", "loss_alpha"
                ).detach()

                # Detached like the losses above: this is kept for logging only.
                alphas[j] = loss_td.select("alpha").detach()


                # Update qnet_target params
                polyak_updater.step()

            # Some other stuff I ripped out from https://github.com/pytorch/rl/blob/main/sota-implementations/sac/sac.py
            # Use the "done" mask if any episode finished in this batch,
            # otherwise fall back to the "truncated" mask.
            episode_end = (
                tensordict["next", "done"]
                if tensordict["next", "done"].any()
                else tensordict["next", "truncated"]
            )
            opening_banner = "-" * 10 + f" Batch {i + 1} " + "-" * 10

            def get_mean(src, key): 
                return src.get(key).mean().item()

            logger.info(opening_banner)
            logger.info(f"Average Actor Loss: {get_mean(losses, 'loss_actor')}")
            logger.info(f"Average Q Loss: {get_mean(losses, 'loss_qvalue')}")
            logger.info(f"Average Alpha: {get_mean(alphas, 'alpha')} (Loss: {get_mean(losses, 'loss_alpha')})")
            logger.info("-" * len(opening_banner))
            
            
            ep_length = tensordict['next', 'step_count'][episode_end].to(dtype = torch.float64)

            if ep_length.numel():
                ep_lengths.append(ep_length.mean().item())

        # Per-agent end-of-episode mask, built the same way as `episode_end`
        # but from the ("walker", ...) entries; used to pick episode rewards.
        agent_terminated = torch.stack(
            [
                tensordict["next", "walker", "done"][:, agent_id, 0]
                if tensordict["next", "walker", "done"][:, agent_id, 0].any()
                else tensordict["next", "walker", "truncated"][:, agent_id, 0]
                for agent_id in range(NUM_AGENTS)
            ], 
            dim = 1
        )
        train_reward = tensordict['next', 'walker', 'episode_reward'][agent_terminated]

        if train_reward.numel():
            train_rews.append(train_reward.mean().item())

        # Periodic evaluation.
        if not ((i + 1) % (EVAL_INTERVAL * EXPLORATION_BATCHES)):
            logger.info(
                        f"Mean Train Reward Across Past {EVAL_INTERVAL} Collections: " +
                        (
                            f"{sum(train_rews) / len(train_rews)}" 
                            if len(train_rews) 
                            else f"NA (Training starts @ {WARMUP_STEPS} steps)"
                        )
                    )
            # Evaluate with the distribution mode (no sampling) and no gradients.
            with set_exploration_type(ExplorationType.MODE), torch.no_grad():
                eval_rollout = eval_env.rollout(
                    MAX_EPISODE_STEPS,
                    policy_actor,
                    auto_cast_to_device=True,
                    break_when_any_done=True,
                )

                mean_eval_length = eval_rollout["next", "step_count"][-1].to(dtype = torch.float64).mean().item()
                logger.info(f"Mean Eval Reward: {eval_rollout['next', 'walker', 'episode_reward'][-1].mean().item()}")
                logger.info(f"Eval Length: {mean_eval_length}")
            # Start a fresh reward window and dump the evaluation transforms
            # (which include the video recorder).
            train_rews = []
            eval_env.transform.dump()

    # Release collector workers and both environments.
    collector.shutdown()
    train_env.close()
    eval_env.close()

Expected behavior

Loss values stay within ~ +/- 10^2 throughout training and do not increase to ~ +/- 10^x where x >> 1.

System info

>>> import torchrl, numpy, sys
>>> print(f"TorchRL: {torchrl.__version__}\nNumPy: {numpy.__version__}\nPython3 Ver: {sys.version}\nPlatform: {sys.platform}")
TorchRL: 0.4.0 
NumPy: 1.25.0 
Python3 Ver: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] 
Platform: linux
> lsb_release -a
Distributor ID:	Ubuntu
Description:	Ubuntu 22.04.3 LTS
Release:	22.04
Codename:	jammy

Reason and Possible fixes

Though the environment's observation space is not normalized and carries unbounded entries, the issue does not appear to entirely arise from the poor observation scaling, since adding a torchrl.envs.ObservationNorm does not mitigate the issue.

Debugging reveals that unusually large and negative values for log_prob are somehow being fed into the SACLoss calculations from the reimplementation of torch.distributions.transforms.TanhTransform.

class TanhNormal(FasterTransformedDistribution):
    """Implements a TanhNormal distribution with location scaling.

    Location scaling prevents the location to be "too far" from 0 when a
    ``TanhTransform`` is applied, but ultimately
    leads to numerically unstable samples and poor gradient computation
    (e.g. gradient explosion).
    In practice, with location scaling the location is computed according to

        .. math::
            loc = tanh(loc / upscale) * upscale.

    Args:
        loc (torch.Tensor): normal distribution location parameter
        scale (torch.Tensor): normal distribution sigma parameter (squared root of variance)
        upscale (torch.Tensor or number): 'a' scaling factor in the formula:

            .. math::
                loc = tanh(loc / upscale) * upscale.

        min (torch.Tensor or number, optional): minimum value of the distribution. Default is -1.0;
        max (torch.Tensor or number, optional): maximum value of the distribution. Default is 1.0;
        event_dims (int, optional): number of dimensions describing the action.
            Default is 1;
        tanh_loc (bool, optional): if ``True``, the above formula is used for the location scaling, otherwise the raw
            value is kept. Default is ``False``;
    """

    arg_constraints = {
        "loc": constraints.real,
        "scale": constraints.greater_than(1e-6),
    }

    num_params = 2

    def __init__(
        self,
        loc: torch.Tensor,
        scale: torch.Tensor,
        upscale: Union[torch.Tensor, Number] = 5.0,
        min: Union[torch.Tensor, Number] = -1.0,
        max: Union[torch.Tensor, Number] = 1.0,
        event_dims: int = 1,
        tanh_loc: bool = False,
    ):
        # Validate that max > min for tensor, plain-number and other
        # (iterable) bounds.
        err_msg = "TanhNormal max values must be strictly greater than min values"
        if isinstance(max, torch.Tensor) or isinstance(min, torch.Tensor):
            if not (max > min).all():
                raise RuntimeError(err_msg)
        elif isinstance(max, Number) and isinstance(min, Number):
            if not max > min:
                raise RuntimeError(err_msg)
        else:
            if not all(max > min):
                raise RuntimeError(err_msg)

        # Record whether either bound differs from the default (-1, 1) range.
        if isinstance(max, torch.Tensor):
            self.non_trivial_max = (max != 1.0).any()
        else:
            self.non_trivial_max = max != 1.0

        if isinstance(min, torch.Tensor):
            self.non_trivial_min = (min != -1.0).any()
        else:
            self.non_trivial_min = min != -1.0

        self.tanh_loc = tanh_loc
        self._event_dims = event_dims

        self.device = loc.device
        self.upscale = (
            upscale
            if not isinstance(upscale, torch.Tensor)
            else upscale.to(self.device)
        )

        # Tensor bounds are moved to the same device as `loc`.
        if isinstance(max, torch.Tensor):
            max = max.to(loc.device)
        if isinstance(min, torch.Tensor):
            min = min.to(loc.device)
        self.min = min
        self.max = max

        # Tanh squashing, followed (only for non-default bounds) by an affine
        # map that rescales (-1, 1) to (min, max).
        t = SafeTanhTransform()
        if self.non_trivial_max or self.non_trivial_min:
            t = D.ComposeTransform(
                [
                    t,
                    D.AffineTransform(loc=(max + min) / 2, scale=(max - min) / 2),
                ]
            )
        self._t = t
        # NOTE(review): the quoted excerpt stops here; the remainder of
        # __init__ is not shown in this issue.

Since this reimplementation does not change much from the original TanhTransform, it is plausible that the reimplementation is NOT the root cause of the error. Nevertheless, replacing the reimplementation with an alternative variant gets rid of the issue altogether:

class CustomTanhTransform(D.transforms.TanhTransform):
    """``TanhTransform`` variant whose inverse clamps its input away from -1 and 1."""

    def _inverse(self, y):
        """Return ``atanh(y)`` after clamping ``y`` into ``[-1 + EPS, 1 - EPS]``.

        Uses ``atanh(y) = 0.5 * (log(1 + y) - log(1 - y))``. The approach is
        credited in the original to stable_baselines3's
        ``common.distributions.TanhBijector`` and to Pyro
        (https://github.com/pyro-ppl/pyro).
        """
        bounded = y.clamp(-1. + EPS, 1. - EPS)
        log_one_plus = bounded.log1p()
        log_one_minus = (-bounded).log1p()
        return 0.5 * (log_one_plus - log_one_minus)

    def log_abs_det_jacobian(self, x, y):
        """Return ``log(1 - tanh(x)^2)`` computed from ``x`` alone (``y`` is unused).

        Same formula as PyTorch's ``TanhTransform``:

            log(1 - tanh^2(x)) = log(sech^2(x))
                               = 2 * log(2 / (e^x + e^(-x)))
                               = 2 * (log 2 - log(e^x * (1 + e^(-2x))))
                               = 2 * (log 2 - x - log(1 + e^(-2x)))
                               = 2 * (log 2 - x - softplus(-2x))
        """
        softplus_term = nn.functional.softplus(-2.0 * x)
        return 2.0 * (math.log(2.0) - x - softplus_term)

class TanhNormalStable(D.TransformedDistribution):
    """Tanh-squashed Normal built on ``CustomTanhTransform`` with a floored ``log_prob``."""

    def __init__(self, loc, scale, event_dims = 1):
        # The parent constructor is invoked from `update`, which needs
        # `_t` and `_event_dims` to be set first.
        self._t = [
            CustomTanhTransform()
        ]
        self._event_dims = event_dims
        self.update(loc, scale)

    def log_prob(self, value):
        """Score ``value`` under the transformed distribution.

        Walks the transforms in reverse, subtracting each log-abs-det-Jacobian,
        then adds the base distribution's log-probability of the fully
        inverted value. The result is clamped from below at
        ``math.log10(EPS)``.
        """
        if self._validate_args:
            self._validate_sample(value)

        n_event_dims = len(self.event_shape)
        total = 0.0
        outer = value
        for tf in reversed(self.transforms):
            inner = tf.inv(outer)
            n_event_dims += tf.domain.event_dim - tf.codomain.event_dim
            jacobian = tf.log_abs_det_jacobian(inner, outer)
            total = total - D.utils._sum_rightmost(
                jacobian,
                n_event_dims - tf.domain.event_dim,
            )
            outer = inner

        base_score = self.base_dist.log_prob(outer)
        total = total + D.utils._sum_rightmost(
            base_score, n_event_dims - len(self.base_dist.event_shape)
        )

        # <- **CLAMPING THIS SEEMS TO RESOLVE THE ISSUE**
        # NOTE(review): the floor is a base-10 log of EPS, while the quantity
        # being clamped is built from `log_prob` terms -- confirm this is intended.
        return torch.clamp(total, min = math.log10(EPS))

    def update(self, loc: torch.Tensor, scale: torch.Tensor) -> None:
        """Set new ``loc``/``scale``.

        If a base distribution already exists and both shapes match, the
        parameters are swapped in place; otherwise the distribution is
        (re)initialised.
        """
        self.loc = loc
        self.scale = scale

        can_reuse = hasattr(self, "base_dist")
        if can_reuse:
            normal = self.base_dist.base_dist
            can_reuse = (
                normal.loc.shape == self.loc.shape
                and normal.scale.shape == self.scale.shape
            )

        if can_reuse:
            normal.loc = self.loc
            normal.scale = self.scale
            return

        base = D.Independent(D.Normal(self.loc, self.scale), self._event_dims)
        super().__init__(base, self._t)

    @property
    def mode(self):
        """The base Normal's mean pushed through every transform in order."""
        point = self.base_dist.base_dist.mean
        for tf in self.transforms:
            point = tf(point)
        return point

although such a fix flies in the face of this comment from the PyTorch devs.

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

Metadata

Metadata

Assignees

Labels

bug: Something isn't working

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions