Commit 41d6c9c

bartosz-grabowski, ericspod, and KumoLiu authored
8201 Fix DataFrame subsets indexing in CSVDataset() (#8351)
Fixes #8201.

### Description

`convert_tables_to_dicts` was using `.loc` to index DataFrames which was changed to `.iloc`. It was causing unexpected behavior in `CSVDataset` as demonstrated in #8201 because `.loc` expects labels, but was instead provided positions of the rows. A unit test was added which fails before the change and passes afterwards.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Bartosz Grabowski <58475557+bartosz-grabowski@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
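The label-versus-position distinction described above can be illustrated with a minimal sketch. The DataFrame and its `ehr_0` values are hypothetical and are not taken from the MONAI test data; only standard pandas indexing is used.

```python
import pandas as pd

# Hypothetical data: five rows with the default RangeIndex 0..4.
df = pd.DataFrame({"ehr_0": [10, 11, 12, 13, 14]})

# Keep positions 1, 3, 4; the subset retains the original labels 1, 3, 4.
subset = df.iloc[[1, 3, 4]]
labels = subset.index.tolist()

# .iloc counts rows from zero, .loc looks up index labels.
by_position = subset.iloc[1]["ehr_0"]  # second row of the subset
by_label = subset.loc[1]["ehr_0"]      # row whose label is 1

# A position that is not also a label cannot be resolved by .loc.
try:
    subset.loc[0]
    missing_label_raises = False
except KeyError:
    missing_label_raises = True
```

On a freshly loaded DataFrame the labels and positions coincide, which is presumably why the problem only appears once a subset with a non-contiguous index is passed in.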
1 parent 34f3797 commit 41d6c9c

2 files changed, with 15 additions and 1 deletion
monai/data/utils.py

Lines changed: 1 addition & 1 deletion
@@ -1473,7 +1473,7 @@ def convert_tables_to_dicts(
     # parse row indices
     rows: list[int | str] = []
     if row_indices is None:
-        rows = slice(df.shape[0])  # type: ignore
+        rows = df.index.tolist()
     else:
         for i in row_indices:
             if isinstance(i, (tuple, list)):
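As a rough sketch of what this hunk changes, the snippet below compares the old and new default for `rows` on a DataFrame subset. It assumes that `rows` is later passed to `.loc`, as the description states; the surrounding function body is not shown in this diff, and the data here is hypothetical.

```python
import pandas as pd

# Hypothetical data standing in for a loaded CSV file.
df = pd.DataFrame({"ehr_0": [10, 11, 12, 13, 14]})
df_subset = df.iloc[[1, 3, 4]]  # index labels are now 1, 3, 4

# Old default: a positional-looking slice. Handed to .loc it becomes a
# label slice up to and including label 3, so the row labelled 4 is lost.
rows_before = slice(df_subset.shape[0])
selected_before = df_subset.loc[rows_before]

# New default: the actual index labels, so .loc selects every row.
rows_after = df_subset.index.tolist()
selected_after = df_subset.loc[rows_after]
```

Under these assumptions the old default silently drops a row from the subset, while the new one keeps all three, which matches the `len(dataset) == 3` expectation in the added test.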

tests/data/test_csv_dataset.py

Lines changed: 14 additions & 0 deletions
@@ -179,6 +179,20 @@ def prepare_csv_file(data, filepath):
                 },
             )
 
+            # test pre-loaded DataFrame subset
+            df = pd.read_csv(filepath1)
+            df_subset = df.iloc[[1, 3, 4]]
+            dataset = CSVDataset(src=df_subset, col_groups={"ehr": [f"ehr_{i}" for i in range(3)]})
+            self.assertEqual(len(dataset), 3)
+            np.testing.assert_allclose([round(i, 4) for i in dataset[1]["ehr"]], [3.3333, 3.2353, 3.4000])
+
+            # test pre-loaded DataFrame subset with row_indices != None
+            df = pd.read_csv(filepath1)
+            df_subset = df.iloc[[1, 3, 4]]
+            dataset = CSVDataset(src=df_subset, row_indices=[1, 3], col_groups={"ehr": [f"ehr_{i}" for i in range(3)]})
+            self.assertEqual(len(dataset), 2)
+            np.testing.assert_allclose([round(i, 4) for i in dataset[1]["ehr"]], [3.3333, 3.2353, 3.4000])
+
             # test pre-loaded multiple DataFrames, join tables with kwargs
             dfs = [pd.read_csv(i) for i in filepaths]
             dataset = CSVDataset(src=dfs, on="subject_id")
