From b3e83a98337d12b24d3b6ca4973d3e986c1dc21e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20M=C3=A4rz?= Date: Mon, 28 Aug 2023 16:23:38 +0200 Subject: [PATCH] Update stabilize_derivative --- lightgbmlss/distributions/mixture_distribution_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/lightgbmlss/distributions/mixture_distribution_utils.py b/lightgbmlss/distributions/mixture_distribution_utils.py index c4d5afe..9d49360 100644 --- a/lightgbmlss/distributions/mixture_distribution_utils.py +++ b/lightgbmlss/distributions/mixture_distribution_utils.py @@ -596,6 +596,13 @@ def stabilize_derivative(self, input_der: torch.Tensor, type: str = "MAD") -> to div = torch.where(div > torch.tensor(10000.0), torch.tensor(10000.0), div) stab_der = input_der / div + if type == "None": + stab_der = torch.nan_to_num(input_der, + nan=float(torch.nanmean(input_der)), + posinf=float(torch.nanmean(input_der)), + neginf=float(torch.nanmean(input_der)) + ) + return stab_der def dist_select(self,