Skip to content

ADVI not working with the pm.Censored module  #6285

Open
@kylejcaron

Description

@kylejcaron

Description of your problem

I have hierarchical censored Weibull data that can be fit with MCMC, but ADVI fails. I would love to fix this so I can port my model at work over to PyMC v4.

Please provide a minimal, self-contained, and reproducible example.
Apologies, it is difficult to make a survival simulation truly minimal.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stats
import scipy.special as sp
import pymc as pm
import arviz as az
# Seed passed to sim_data() when the example dataset is built further down.
SEED = 99

def sim_global_params(params, N_groups, seed=None):
    """Draw per-group log-parameters from the hyper-parameters in ``params``.

    NumPy's global random state is reseeded with ``seed`` first. A shallow
    copy of ``params`` is returned with two extra entries, ``"log_lambd"``
    and ``"log_k"``; each is an array of ``N_groups`` normal draws whose
    location and scale come from the matching ``*_mu`` / ``*_sig`` entries.
    The dict passed in is not modified.
    """
    np.random.seed(seed)
    drawn = params.copy()

    # (target key, location key, scale key); the order fixes the random stream,
    # so log_lambd must be drawn before log_k.
    draw_specs = (
        ("log_lambd", "log_lambd_mu", "log_lambd_sig"),
        ("log_k", "log_k_mu", "log_k_sig"),
    )
    for target, loc_key, scale_key in draw_specs:
        drawn[target] = np.random.normal(
            drawn[loc_key], drawn[scale_key], size=N_groups
        )

    return drawn

def sim_data(params, N=2500, N_groups=100, seed=None):
    """Simulate a grouped, censored event-time dataset.

    Returns a DataFrame with one row per observation and three columns:
    ``group`` (index of the group the row belongs to), ``time`` (the event
    time, or the censoring time when the row is censored) and ``cens``
    (1 when the simulated event time exceeds the censoring time, else 0).
    """
    np.random.seed(seed)

    # Per-group log-parameters; note this helper reseeds NumPy with `seed` too.
    group_params = sim_global_params(params, N_groups, seed=seed)

    # Which group each observation belongs to.
    membership = np.random.choice(range(N_groups), size=N)

    # Event times: one Weibull draw per observation using its group's parameters.
    k_obs = np.exp(group_params["log_k"][membership])
    lambd_obs = np.exp(group_params["log_lambd"][membership])
    event_time = pm.Weibull.dist(k_obs, lambd_obs).eval()

    # Random censoring times (lognormal draws truncated to integers) to mimic
    # survival analysis; np.random.uniform(0, 250, size=N) was an earlier alternative.
    censor_time = np.random.lognormal(4, 0.75, size=N).astype(int)

    frame = pd.DataFrame({"group": membership, "time": event_time})

    # cens == 1 flags rows whose event has not occurred by the censoring time.
    censored = np.where(frame["time"] <= censor_time, 0, 1)
    frame["cens"] = censored

    # For censored rows the latest time observed is the censoring time.
    frame["time"] = np.where(censored == 1, censor_time, frame["time"])

    return frame


def hierarchical_normal(name, dims, μ=0., nc=True):
    """Build a group-level normal effect called ``name`` from PyMC variables.

    With ``nc=True`` (the default) this creates ``Δ_<name>`` ~ Normal(0, 1)
    over ``dims`` and ``σ_<name>`` ~ Exponential(5), and returns a
    Deterministic named ``name`` equal to ``μ + Δ * σ``.

    With ``nc=False`` it creates ``μ_<name>`` ~ Normal(μ, 1) and
    ``σ_<name>`` ~ Exponential(5), and returns ``name`` ~ Normal(mu, sigma)
    over ``dims``.
    """
    scale_label = "σ_{}".format(name)

    if not nc:
        # Direct form: the effect is itself a Normal around a learned mean.
        center = pm.Normal("μ_{}".format(name), μ, 1)
        spread = pm.Exponential(scale_label, 5.)
        return pm.Normal(name, center, spread, dims=dims)

    # Offset form: standard-normal offsets scaled by a shared spread.
    offset = pm.Normal("Δ_{}".format(name), 0., 1., dims=dims)
    spread = pm.Exponential(scale_label, 5.)
    return pm.Deterministic(name, μ + offset * spread)

# Hyper-parameters used to simulate the data: mean and standard deviation of
# the per-group log(lambd) and log(k) values drawn in sim_global_params().
params = dict(
    log_lambd_mu = np.log(65),
    log_lambd_sig = 0.4,
    log_k_mu = np.log(1.65),
    log_k_sig = 0.2,
)

N_groups = 100
data = sim_data(params, N=2500, N_groups=N_groups, seed=SEED)
# Coordinates for the model: one "group" label per simulated group.
COORDS = {"group":np.arange(N_groups)}


with pm.Model(coords=COORDS) as m_weibull:
                        
    # T: recorded time for every row (the censoring time for censored rows).
    T = pm.ConstantData("T", data.time.values)
    # E: event indicator, 1 where the row is not censored.
    # It is not referenced again in the model shown here.
    E = pm.ConstantData("E", np.where(data.cens==1, 0, 1))
    # Upper censoring bound per row: the recorded time for censored rows,
    # +inf for rows whose event was observed.
    cens_ = pm.ConstantData("cens", np.where(data.cens==1, data.time, np.inf))
    # Group index of each row, used to pick that row's group-level parameters.
    g_ = pm.MutableData("group_idx", data.group.values)
    
    
    # Priors on the population means of log(k) and log(lambd); their centres sit
    # close to the simulation values log(1.65) ~= 0.50 and log(65) ~= 4.17.
    mu_log_k = pm.Normal("mu_log_k", 0.5, 0.25)
    mu_log_lambd = pm.Normal("mu_log_lambd", 4.15, 0.25)

    # Per-group log-parameters, built with the nc=True branch of hierarchical_normal().
    log_k = hierarchical_normal("log_k", μ=mu_log_k, dims="group", nc=True)
    log_lambd = hierarchical_normal("log_lambd", μ=mu_log_lambd, dims="group", nc=True)

    # Back-transform to the positive scale and expose the values per group.
    k = pm.Deterministic("k", pm.math.exp(log_k), dims="group")
    lambd = pm.Deterministic("lambd", pm.math.exp(log_lambd), dims="group")

    # Unregistered Weibull distribution with one (k, lambd) pair per row;
    # it is wrapped by pm.Censored below.
    y_latent = pm.Weibull.dist(k[g_], lambd[g_])

                
    # Censored likelihood: no lower bound, per-row upper bound cens_, observed T.
    obs = pm.Censored("obs", y_latent,  
                       lower=None, 
                      upper=cens_,
                      observed=T)
    # MCMC sampling is reported to work for this model:
    # idata = pm.sample()
    # ADVI is reported to fail with "FloatingPointError: NaN occurred in
    # optimization" (see the traceback in this report).
    approx = pm.fit(method="advi")

Please provide the full traceback.

Complete error traceback
FloatingPointError                        Traceback (most recent call last)
Input In [34], in <cell line: 5>()
     26     obs = pm.Censored("obs", y_latent,  
     27                        lower=None, 
     28                       upper=cens_,
     29                       observed=T)
     30     #works
     31 #     idata = pm.sample()
     32     # doesnt work
---> 33     approx = pm.fit(method="advi")

File ~/.pyenv/versions/3.9.7/envs/default_venv/lib/python3.9/site-packages/pymc/variational/inference.py:753, in fit(n, method, model, random_seed, start, start_sigma, inf_kwargs, **kwargs)
    751 else:
    752     raise TypeError(f"method should be one of {set(_select.keys())} or Inference instance")
--> 753 return inference.fit(n, **kwargs)

File ~/.pyenv/versions/3.9.7/envs/default_venv/lib/python3.9/site-packages/pymc/variational/inference.py:144, in Inference.fit(self, n, score, callbacks, progressbar, **kwargs)
    142     progress = range(n)
    143 if score:
--> 144     state = self._iterate_with_loss(0, n, step_func, progress, callbacks)
    145 else:
    146     state = self._iterate_without_loss(0, n, step_func, progress, callbacks)

File ~/.pyenv/versions/3.9.7/envs/default_venv/lib/python3.9/site-packages/pymc/variational/inference.py:230, in Inference._iterate_with_loss(self, s, n, step_func, progress, callbacks)
    228     except IndexError:
    229         pass
--> 230     raise FloatingPointError("\n".join(errmsg))
    231 scores[i] = e
    232 if i % 10 == 0:

FloatingPointError: NaN occurred in optimization. 
The current approximation of RV `mu_log_k`.ravel()[0] is NaN.
The current approximation of RV `mu_log_lambd`.ravel()[0] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[0] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[1] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[2] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[3] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[4] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[5] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[6] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[7] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[8] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[9] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[10] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[11] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[12] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[13] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[14] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[15] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[16] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[17] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[18] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[19] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[20] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[21] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[22] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[23] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[24] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[25] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[26] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[27] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[28] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[29] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[30] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[31] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[32] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[33] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[34] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[35] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[36] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[37] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[38] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[39] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[40] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[41] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[42] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[43] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[44] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[45] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[46] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[47] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[48] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[49] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[50] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[51] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[52] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[53] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[54] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[55] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[56] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[57] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[58] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[59] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[60] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[61] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[62] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[63] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[64] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[65] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[66] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[67] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[68] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[69] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[70] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[71] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[72] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[73] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[74] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[75] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[76] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[77] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[78] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[79] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[80] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[81] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[82] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[83] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[84] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[85] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[86] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[87] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[88] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[89] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[90] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[91] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[92] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[93] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[94] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[95] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[96] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[97] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[98] is NaN.
The current approximation of RV `Δ_log_k`.ravel()[99] is NaN.
The current approximation of RV `σ_log_k_log__`.ravel()[0] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[0] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[1] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[2] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[3] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[4] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[5] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[6] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[7] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[8] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[9] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[10] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[11] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[12] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[13] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[14] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[15] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[16] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[17] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[18] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[19] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[20] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[21] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[22] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[23] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[24] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[25] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[26] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[27] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[28] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[29] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[30] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[31] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[32] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[33] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[34] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[35] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[36] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[37] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[38] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[39] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[40] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[41] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[42] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[43] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[44] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[45] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[46] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[47] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[48] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[49] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[50] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[51] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[52] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[53] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[54] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[55] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[56] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[57] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[58] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[59] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[60] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[61] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[62] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[63] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[64] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[65] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[66] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[67] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[68] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[69] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[70] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[71] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[72] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[73] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[74] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[75] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[76] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[77] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[78] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[79] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[80] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[81] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[82] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[83] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[84] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[85] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[86] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[87] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[88] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[89] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[90] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[91] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[92] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[93] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[94] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[95] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[96] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[97] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[98] is NaN.
The current approximation of RV `Δ_log_lambd`.ravel()[99] is NaN.
The current approximation of RV `σ_log_lambd_log__`.ravel()[0] is NaN.
Try tracking this parameter: http://docs.pymc.io/notebooks/variational_api_quickstart.html#Tracking-parameters

Please provide any additional information below.

Versions and main components

  • PyMC/PyMC3 Version: 4.3.0
  • Aesara/Theano Version: 2.8.7
  • Python Version: 3.9.7
  • Operating system: Mac OSX (M1, Darwin)
  • How did you install PyMC/PyMC3: (conda/pip) pip

Metadata

Metadata

Assignees

No one assigned

    Labels

    VI (Variational Inference), bug

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions