Description
When a state variable is used only in transition functions (not in utility or constraints), and the regime's only next-period target is one for which that transition is irrelevant (for example, a terminal regime), Model construction fails in vmap_1d with ValueError: list.index(x): x not in list.
Root cause
At the last period of a non-terminal regime (where the only active target for the next period is a terminal regime like "dead"), the Q_and_F function's parameter list is computed via _get_arg_names_of_Q_and_F, which takes the union of parameters across all dependency functions. Since state transitions are not needed when transitioning to a terminal regime (which has no states), any state variable that only appears in transition functions is pruned from the Q_and_F signature.
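To make the pruning concrete, here is a minimal sketch of an ordered union of parameter names over a set of dependency functions. It makes three assumptions:
- It reuses the reproducer's function signatures.
- It assumes the last-period Q_and_F still depends on utility and next_regime but no longer on the state transitions.
- union_of_arg_names is a hypothetical helper built on inspect.signature. It is not pylcm's _get_arg_names_of_Q_and_F, and the function bodies are placeholders.

```python
import inspect
from collections.abc import Callable


# Placeholder bodies: only the signatures matter for this sketch.
def utility(consumption, wealth):
    return consumption + wealth


def next_wealth(wealth, consumption, type_var):
    return wealth - consumption


def next_type_var(type_var):
    return type_var


def next_regime(age):
    return 0


def union_of_arg_names(funcs: list[Callable]) -> list[str]:
    """Ordered union of parameter names across funcs (hypothetical helper)."""
    names: list[str] = []
    for func in funcs:
        for name in inspect.signature(func).parameters:
            if name not in names:
                names.append(name)
    return names


# Non-terminal target: state transitions are needed, so type_var survives.
full = union_of_arg_names([utility, next_wealth, next_type_var, next_regime])

# Terminal-only target (assumed): state transitions are dropped, so type_var is pruned.
pruned = union_of_arg_names([utility, next_regime])
```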
However, build_argmax_and_max_Q_over_a_functions calls simulation_spacemap with states_names=tuple(state_action_space.states), which still includes all states defined in the current regime. When vmap_1d tries to find the pruned state variable's position in the Q_and_F function's parameters, it fails because the variable is not in the signature.
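The failing line in vmap_1d looks up each state's position with list.index. The sketch below replays that lookup against a pruned parameter list. positions_of and both input lists are illustrative and mirror the reproducer rather than pylcm's internals.

```python
def positions_of(parameters: list[str], variables: tuple[str, ...]) -> list[int]:
    """Return the index of each variable in parameters (same pattern as the traceback)."""
    return [parameters.index(var) for var in variables]


# Illustrative pruned signature of the last-period Q_and_F.
parameters = ["consumption", "wealth", "age"]

# A state that is still in the signature is found without problems.
wealth_positions = positions_of(parameters, ("wealth",))

# type_var was pruned, so list.index raises ValueError, as in the report.
raised = False
try:
    positions_of(parameters, ("wealth", "type_var"))
except ValueError:
    raised = True
```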
The mismatch is between:
- state_action_space.states: all states declared for the regime, including type_var.
- The Q_and_F function signature at the last period: only the states actually needed, which excludes type_var.
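One way to surface this mismatch before vmap_1d runs is to compare the declared states against the function's signature. This is a hedged diagnostic sketch. q_and_f is a stand-in with the pruned signature, and states_missing_from_signature is a hypothetical helper, not part of pylcm.

```python
import inspect
from collections.abc import Callable


def q_and_f(consumption, wealth, age):
    """Stand-in for the last-period Q_and_F: type_var has been pruned."""
    return consumption + wealth + age


def states_missing_from_signature(func: Callable, states_names: tuple[str, ...]) -> list[str]:
    """Return the declared states that func does not accept (hypothetical helper)."""
    params = inspect.signature(func).parameters
    return [name for name in states_names if name not in params]


missing = states_missing_from_signature(q_and_f, ("wealth", "type_var"))
```

Dropping the missing names from states_names would avoid the exception. It would also change which axes are mapped over, so it is not obviously the right fix.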
Minimal reproducer
```python
import jax.numpy as jnp

from lcm import AgeGrid, DiscreteGrid, LinSpacedGrid, Model, Regime, categorical


@categorical
class RegimeId:
    alive: int
    dead: int


@categorical
class TypeVar:
    low: int
    high: int


def utility(consumption, wealth):
    return jnp.log(consumption) + 0.01 * wealth


def dead_utility():
    return 0.0


def next_wealth(wealth, consumption, type_var):
    """type_var affects wealth transition but does NOT appear in utility."""
    return (1 + 0.05 * type_var) * (wealth - consumption)


def next_type_var(type_var):
    return type_var


def next_regime(age):
    return jnp.where(age >= 2, RegimeId.dead, RegimeId.alive)


ages = AgeGrid(start=0, stop=3, step="Y")

alive = Regime(
    utility=utility,
    states={
        "wealth": LinSpacedGrid(start=1, stop=100, n_points=10),
        "type_var": DiscreteGrid(TypeVar),
    },
    actions={
        "consumption": LinSpacedGrid(start=1, stop=50, n_points=10),
    },
    transitions={
        "next_wealth": next_wealth,
        "next_type_var": next_type_var,
        "next_regime": next_regime,
    },
    active=lambda age: age <= 2,
)

dead = Regime(terminal=True, utility=dead_utility)

model = Model(
    regimes={"alive": alive, "dead": dead},
    ages=ages,
    regime_id_class=RegimeId,
)
```
Stack trace
```
  File "lcm/model.py", line 108, in __init__
    self.internal_regimes = process_regimes(
  File "lcm/input_processing/regime_processing.py", line 185, in process_regimes
    argmax_and_max_Q_over_a_functions = build_argmax_and_max_Q_over_a_functions(
  File "lcm/input_processing/regime_components.py", line 142, in build_argmax_and_max_Q_over_a_functions
    argmax_and_max_Q_over_a_functions[period] = simulation_spacemap(
  File "lcm/dispatchers.py", line 68, in simulation_spacemap
    vmapped = vmap_1d(vmapped, variables=states_names, callable_with="only_args")
  File "lcm/dispatchers.py", line 116, in vmap_1d
    positions = [parameters.index(var) for var in variables]
ValueError: list.index(x): x not in list
```
The variables tuple passed to vmap_1d contains all regime states (including type_var). The Q_and_F parameter list at the last period does not contain type_var, because type_var is used only in transition functions. Those transitions are irrelevant when the only next-period target is a terminal regime.
Version
pylcm 0.0.2.dev60+g008e443a5 (current main or any dev branch)