Skip to content

Commit

Permalink
Added creation of features file
Browse files Browse the repository at this point in the history
  • Loading branch information
gcasadesus committed Aug 6, 2021
1 parent 42893b7 commit 0445030
Showing 1 changed file with 48 additions and 19 deletions.
67 changes: 48 additions & 19 deletions dislib/commons/rf/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def get_classes(self):


def transform_to_rf_dataset(
x: Array, y: Array, task: str
x: Array, y: Array, task: str, features_file=False
) -> RfRegressorDataset or RfClassifierDataset:
"""Creates a RfDataset object from samples x and targets y.
Expand All @@ -277,6 +277,7 @@ def transform_to_rf_dataset(
n_samples = x.shape[0]
n_features = x.shape[1]

# Samples
samples_file = tempfile.NamedTemporaryFile(
mode="wb", prefix="tmp_rf_samples_", delete=False
)
Expand All @@ -293,6 +294,7 @@ def transform_to_rf_dataset(
_fill_samples_file(samples_path, x_row._blocks, start_idx)
start_idx += x._reg_shape[0]

# Targets
targets_file = tempfile.NamedTemporaryFile(
mode="w", prefix="tmp_rf_targets_", delete=False
)
Expand All @@ -301,10 +303,34 @@ def transform_to_rf_dataset(
for y_row in y._iterator(axis=0):
_fill_targets_file(targets_path, y_row._blocks)

# Features
if features_file:
features_file = tempfile.NamedTemporaryFile(
mode="wb", prefix="tmp_rf_features_", delete=False
)
features_path = features_file.name
features_file.close()
_allocate_features_file(features_path, n_samples, n_features)

start_idx = 0
row_blocks_iterator = x._iterator(axis=0)
top_row = next(row_blocks_iterator)
_fill_features_file(features_path, top_row._blocks, start_idx)
start_idx += x._top_left_shape[0]
for x_row in row_blocks_iterator:
_fill_features_file(features_path, x_row._blocks, start_idx)
start_idx += x._reg_shape[0]
else:
features_path = None

if task == "classification":
rf_dataset = RfClassifierDataset(samples_path, targets_path)
rf_dataset = RfClassifierDataset(
samples_path, targets_path, features_path
)
elif task == "regression":
rf_dataset = RfRegressorDataset(samples_path, targets_path)
rf_dataset = RfRegressorDataset(
samples_path, targets_path, features_path
)
else:
raise ValueError("task must be either classification or regression.")
rf_dataset.n_samples = n_samples
Expand Down Expand Up @@ -361,21 +387,6 @@ def _get_values(targets_path):
return y.astype(np.float64)


@task(returns=1)
def _get_samples_shape(subset):
return subset.samples.shape


@task(returns=3)
def _merge_shapes(*samples_shapes):
n_samples = 0
n_features = samples_shapes[0][1]
for shape in samples_shapes:
n_samples += shape[0]
assert shape[1] == n_features, "Subsamples with different n_features."
return samples_shapes, n_samples, n_features


@task(samples_path=FILE_INOUT)
def _allocate_samples_file(samples_path, n_samples, n_features):
np.lib.format.open_memmap(
Expand All @@ -386,12 +397,30 @@ def _allocate_samples_file(samples_path, n_samples, n_features):
)


@task(samples_path=FILE_INOUT)
def _allocate_features_file(samples_path, n_samples, n_features):
np.lib.format.open_memmap(
samples_path,
mode="w+",
dtype="float32",
shape=(int(n_features), int(n_samples)),
)


@task(samples_path=FILE_INOUT, row_blocks={Type: COLLECTION_IN, Depth: 2})
def _fill_samples_file(samples_path, row_blocks, start_idx):
rows_samples = Array._merge_blocks(row_blocks)
rows_samples = rows_samples.astype(dtype="float32", casting="same_kind")
samples = np.lib.format.open_memmap(samples_path, mode="r+")
samples[start_idx: start_idx + rows_samples.shape[0]] = rows_samples
samples[start_idx : start_idx + rows_samples.shape[0]] = rows_samples


@task(samples_path=FILE_INOUT, row_blocks={Type: COLLECTION_IN, Depth: 2})
def _fill_features_file(samples_path, row_blocks, start_idx):
rows_samples = Array._merge_blocks(row_blocks)
rows_samples = rows_samples.astype(dtype="float32", casting="same_kind")
samples = np.lib.format.open_memmap(samples_path, mode="r+")
samples[:, start_idx : start_idx + rows_samples.shape[0]] = rows_samples.T


@task(targets_path=FILE_INOUT, row_blocks={Type: COLLECTION_IN, Depth: 2})
Expand Down

0 comments on commit 0445030

Please sign in to comment.