Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions src/lymph/models/bilateral.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,10 @@ def load_patient_data(
"""
self.ipsi.load_patient_data(patient_data, "ipsi", mapping)
self.contra.load_patient_data(patient_data, "contra", mapping)
# Keep all columns except '_model', but from '_model' only keep those with first subheader '#'
cols = [col for col in self.ipsi.patient_data.columns if col[0] != '_model']
cols += [col for col in self.ipsi.patient_data.columns if col[0] == '_model' and col[1] == '#']
self.patient_data = self.ipsi.patient_data[cols]

def state_dist(
self,
Expand Down Expand Up @@ -472,11 +476,14 @@ def patient_likelihoods(
mode: Literal["HMM", "BN"] = "HMM",
) -> np.ndarray:
"""Compute the likelihood of each patient individually."""
joint_state_dist = self.state_dist(t_stage=t_stage, mode=mode)
return matrix.fast_trace(
self.ipsi.diagnosis_matrix(t_stage),
joint_state_dist @ self.contra.diagnosis_matrix(t_stage).T,
)
if mode == 'HMM':
joint_state_dist = self.state_dist(t_stage=t_stage, mode=mode)
return matrix.fast_trace(
self.ipsi.diagnosis_matrix(t_stage),
joint_state_dist @ self.contra.diagnosis_matrix(t_stage).T,
)
else:
warnings.warn("Only HMM implemented for patient likelihoods.",)

def _bn_likelihood(self, log: bool = True, t_stage: str | None = None) -> float:
"""Compute the BN likelihood of data, using the stored params."""
Expand Down
45 changes: 45 additions & 0 deletions src/lymph/models/midline.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,51 @@ def obs_dist(
]
return np.stack(obs_dist)

def patient_likelihoods(
self,
t_stage: str = None,
mode: Literal["HMM", "BN"] = "HMM",
) -> np.ndarray:
if mode != "HMM":
raise NotImplementedError("Only HMM mode is supported as of now.")
ipsi_dist_evo = self.ext.ipsi.state_dist_evo()
contra_dist_evo = {}
contra_dist_evo["noext"], contra_dist_evo["ext"] = self.contra_state_dist_evo()
t_stages = self.t_stages if t_stage is None else [t_stage]
patient_data = self.patient_data.loc[self.patient_data[T_STAGE_COL].isin(t_stages)]
patient_llhs = np.zeros(len(patient_data))
for stage in t_stages:
t_idx = patient_data[T_STAGE_COL] == stage
diag_time_matrix = np.diag(self.get_distribution(stage).pmf)
num_states = ipsi_dist_evo.shape[1]
marg_joint_state_dist = np.zeros(shape=(num_states, num_states))
# see the `Bilateral` model for why this is done in this way.
for case in ["ext", "noext"]:
ext_idx = patient_data[EXT_COL] == (case == "ext")
joint_state_dist = (
ipsi_dist_evo.T @ diag_time_matrix @ contra_dist_evo[case]
)
marg_joint_state_dist += joint_state_dist
_model = getattr(self, case)
llhs = matrix.fast_trace(
_model.ipsi.diagnosis_matrix(stage),
joint_state_dist @ _model.contra.diagnosis_matrix(stage).T,
)
patient_llhs[t_idx & ext_idx] = llhs

try:
marg_patient_llhs = matrix.fast_trace(
self.unknown.ipsi.diagnosis_matrix(stage),
marg_joint_state_dist
@ self.unknown.contra.diagnosis_matrix(stage).T,
)
patient_llhs[t_idx & patient_data[EXT_COL].isna()] = marg_patient_llhs
except AttributeError:
# an AttributeError is raised both when the model has no `unknown`
# attribute and when no data is loaded in the `unknown` model.
pass
return patient_llhs

def _hmm_likelihood(
self,
log: bool = True,
Expand Down
15 changes: 14 additions & 1 deletion src/lymph/models/unilateral.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,20 @@ def obs_dist(
given_state_dist = self.state_dist(t_stage=t_stage, mode=mode)

return given_state_dist @ self.observation_matrix()


def patient_likelihoods(
self,
t_stage: str,
mode: Literal["HMM", "BN"] = "HMM",
) -> np.ndarray:
"""Compute the likelihood of each patient individually."""

if mode == "HMM":
return (self.state_dist(t_stage) @ self.diagnosis_matrix(t_stage).T
)
else:
warnings.warn("Only HMM implemented for patient likelihoods.",)

def _bn_likelihood(self, log: bool = True, t_stage: str | None = None) -> float:
"""Compute the BN likelihood, using the stored params."""
state_dist = self.state_dist(mode="BN")
Expand Down