Skip to content
Merged
2 changes: 1 addition & 1 deletion monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1473,7 +1473,7 @@ def convert_tables_to_dicts(
# parse row indices
rows: list[int | str] = []
if row_indices is None:
rows = slice(df.shape[0]) # type: ignore
rows = df.index.tolist()
else:
for i in row_indices:
if isinstance(i, (tuple, list)):
Expand Down
14 changes: 14 additions & 0 deletions tests/data/test_csv_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,20 @@ def prepare_csv_file(data, filepath):
},
)

# test pre-loaded DataFrame subset
df = pd.read_csv(filepath1)
df_subset = df.iloc[[1, 3, 4]]
dataset = CSVDataset(src=df_subset, col_groups={"ehr": [f"ehr_{i}" for i in range(3)]})
self.assertEqual(len(dataset), 3)
np.testing.assert_allclose([round(i, 4) for i in dataset[1]["ehr"]], [3.3333, 3.2353, 3.4000])

# test pre-loaded DataFrame subset with row_indices != None
df = pd.read_csv(filepath1)
df_subset = df.iloc[[1, 3, 4]]
dataset = CSVDataset(src=df_subset, row_indices=[1, 3], col_groups={"ehr": [f"ehr_{i}" for i in range(3)]})
self.assertEqual(len(dataset), 2)
np.testing.assert_allclose([round(i, 4) for i in dataset[1]["ehr"]], [3.3333, 3.2353, 3.4000])

# test pre-loaded multiple DataFrames, join tables with kwargs
dfs = [pd.read_csv(i) for i in filepaths]
dataset = CSVDataset(src=dfs, on="subject_id")
Expand Down
Loading