Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions pytmle/estimates.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import pandas as pd
from pathlib import Path
from dataclasses import dataclass, field
from typing import Optional, List, Union

Expand Down Expand Up @@ -50,6 +51,51 @@ def __setattr__(self, name, value):
self._check_compatibility(value, check_width=True)
super().__setattr__(name, value)

@classmethod
def from_path(cls, data: pd.DataFrame, path: str | Path):
if type(path) != Path:
path = Path(path)
surv_0_df = pd.read_csv(path / "surv_0.csv")
surv_1_df = pd.read_csv(path / "surv_1.csv")
cens_surv_0_df = pd.read_csv(path / "cens_surv_0.csv")
cens_surv_1_df = pd.read_csv(path / "cens_surv_1.csv")
haz1_0_df = pd.read_csv(path / "haz1_0.csv")
haz1_1_df = pd.read_csv(path / "haz1_1.csv")
haz2_0_df = pd.read_csv(path / "haz2_0.csv")
haz2_1_df = pd.read_csv(path / "haz2_1.csv")
prop_0_df = pd.read_csv(path / "prop_0.csv", header=None)
prop_1_df = pd.read_csv(path / "prop_1.csv", header=None)

# List of dataframes to check
dataframes = [surv_0_df, surv_1_df, cens_surv_0_df, cens_surv_1_df, haz1_0_df, haz1_1_df, haz2_0_df, haz2_1_df]

# Check if all dataframes have the same columns
all_col_equal = all(df.columns.equals(dataframes[0].columns) for df in dataframes)

assert all_col_equal, "All columns must be equal."

dataframes += [data, prop_0_df, prop_1_df]

# Check if all dataframes have the same index
all_indices_equal = all(df.index.equals(dataframes[0].index) for df in dataframes)

assert all_indices_equal, "All indices must be equal"

return {
0: cls(times = surv_0_df.columns.astype(float),
g_star_obs= 1 - data["chemo"].values,
propensity_scores=prop_0_df.values.squeeze(),
hazards=np.stack([haz1_0_df.values, haz2_0_df.values], axis=-1),
event_free_survival_function=surv_0_df.values,
censoring_survival_function=cens_surv_0_df.values),
1: cls(times = surv_1_df.columns.astype(float),
g_star_obs= data["chemo"].values,
propensity_scores=prop_1_df.values.squeeze(),
hazards=np.stack([haz1_1_df.values, haz2_1_df.values], axis=-1),
event_free_survival_function=surv_1_df.values,
censoring_survival_function=cens_surv_1_df.values),
}

def _check_compatibility(self, new_element, check_width):
# check that all given estimates have the same length (first dimension size)
if self._length is None:
Expand Down