diff --git a/dislib/commons/rf/data.py b/dislib/commons/rf/data.py index 8e4bf546..e5155bdc 100644 --- a/dislib/commons/rf/data.py +++ b/dislib/commons/rf/data.py @@ -412,7 +412,7 @@ def _fill_samples_file(samples_path, row_blocks, start_idx): rows_samples = Array._merge_blocks(row_blocks) rows_samples = rows_samples.astype(dtype="float32", casting="same_kind") samples = np.lib.format.open_memmap(samples_path, mode="r+") - samples[start_idx : start_idx + rows_samples.shape[0]] = rows_samples + samples[start_idx: start_idx + rows_samples.shape[0]] = rows_samples @task(samples_path=FILE_INOUT, row_blocks={Type: COLLECTION_IN, Depth: 2}) @@ -420,7 +420,7 @@ def _fill_features_file(samples_path, row_blocks, start_idx): rows_samples = Array._merge_blocks(row_blocks) rows_samples = rows_samples.astype(dtype="float32", casting="same_kind") samples = np.lib.format.open_memmap(samples_path, mode="r+") - samples[:, start_idx : start_idx + rows_samples.shape[0]] = rows_samples.T + samples[:, start_idx: start_idx + rows_samples.shape[0]] = rows_samples.T @task(targets_path=FILE_INOUT, row_blocks={Type: COLLECTION_IN, Depth: 2}) diff --git a/tests/test_rf_dataset.py b/tests/test_rf_dataset.py index 86eceaf8..de55fc76 100644 --- a/tests/test_rf_dataset.py +++ b/tests/test_rf_dataset.py @@ -143,7 +143,7 @@ def _fill_samples_file(samples_path, row_blocks, start_idx, fortran_order): samples = np.lib.format.open_memmap( samples_path, mode="r+", fortran_order=fortran_order ) - samples[start_idx : start_idx + rows_samples.shape[0]] = rows_samples + samples[start_idx: start_idx + rows_samples.shape[0]] = rows_samples def _fill_features_file(samples_path, row_blocks, start_idx, fortran_order): @@ -152,7 +152,7 @@ def _fill_features_file(samples_path, row_blocks, start_idx, fortran_order): samples = np.lib.format.open_memmap( samples_path, mode="r+", fortran_order=fortran_order ) - samples[:, start_idx : start_idx + rows_samples.shape[0]] = rows_samples.T + samples[:, start_idx: start_idx + rows_samples.shape[0]] = rows_samples.T def _fill_targets_file(targets_path, row_blocks):