From ff62b9b14dcf783c7725df3d2247e8ae5eeedf25 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Guillem=20Casades=C3=BAs=20Vila?=
Date: Thu, 22 Jul 2021 15:09:27 +0200
Subject: [PATCH] Added DecisionTreeRegressor with MSE criterion

---
 dislib/regression/rf/decision_tree.py | 397 ++++++++++++++------------
 dislib/regression/rf/test_split.py    |  28 +-
 2 files changed, 234 insertions(+), 191 deletions(-)

diff --git a/dislib/regression/rf/decision_tree.py b/dislib/regression/rf/decision_tree.py
index 0725fcfa..43ecaf79 100644
--- a/dislib/regression/rf/decision_tree.py
+++ b/dislib/regression/rf/decision_tree.py
@@ -5,14 +5,14 @@
 from pycompss.api.api import compss_delete_object
 from pycompss.api.parameter import FILE_IN, Type, COLLECTION_IN, Depth
 from pycompss.api.task import task
-from sklearn.tree import DecisionTreeClassifier as SklearnDTClassifier
+from sklearn.tree import DecisionTreeRegressor as SklearnDTRegressor
 
-from dislib.classification.rf.test_split import test_split
+from dislib.regression.rf.test_split import test_split
 from dislib.data.array import Array
 
 
-class DecisionTreeClassifier:
-    """A distributed decision tree classifier.
+class DecisionTreeRegressor:
+    """A distributed decision tree regressor.
 
     Parameters
     ----------
@@ -39,9 +39,6 @@ class DecisionTreeClassifier:
     n_features : int
         The number of features of the dataset. It can be a
         pycompss.runtime.Future object.
-    n_classes : int
-        The number of classes of this RfDataset. It can be a
-        pycompss.runtime.Future object.
     tree : None or _Node
         The root node of the tree after the tree is fitted.
     nodes_info : None or list of _InnerNodeInfo and _LeafInfo
@@ -56,7 +53,7 @@ class DecisionTreeClassifier:
     Methods
     -------
     fit(dataset)
-        Fits the DecisionTreeClassifier.
+        Fits the DecisionTreeRegressor.
     predict(x_row)
-        Predicts classes for the given samples using a fitted tree.
-    predict_proba(x_row)
+        Predicts target values for the given samples using a fitted tree.
 
     """
 
-    def __init__(self, try_features, max_depth, distr_depth, sklearn_max,
-                 bootstrap, random_state):
+    def __init__(
+        self,
+        try_features,
+        max_depth,
+        distr_depth,
+        sklearn_max,
+        bootstrap,
+        random_state,
+    ):
         self.try_features = try_features
         self.max_depth = max_depth
         self.distr_depth = distr_depth
@@ -74,14 +78,13 @@ def __init__(self, try_features, max_depth, distr_depth, sklearn_max,
         self.random_state = random_state
 
         self.n_features = None
-        self.n_classes = None
 
         self.tree = None
         self.nodes_info = None
         self.subtrees = None
 
     def fit(self, dataset):
-        """Fits the DecisionTreeClassifier.
+        """Fits the DecisionTreeRegressor.
 
         Parameters
         ----------
@@ -90,16 +93,16 @@ def fit(self, dataset):
         """
         self.n_features = dataset.get_n_features()
-        self.n_classes = dataset.get_n_classes()
 
         samples_path = dataset.samples_path
         features_path = dataset.features_path
         n_samples = dataset.get_n_samples()
-        y_codes = dataset.get_y_codes()
+        y_targets = dataset.get_y_targets()
 
         seed = self.random_state.randint(np.iinfo(np.int32).max)
 
-        sample, y_s = _sample_selection(n_samples, y_codes, self.bootstrap,
-                                        seed)
+        sample, y_s = _sample_selection(
+            n_samples, y_targets, self.bootstrap, seed
+        )
 
         self.tree = _Node()
         self.nodes_info = []
@@ -108,11 +111,15 @@
         while tree_traversal:
             node, sample, y_s, depth = tree_traversal.pop()
             if depth < self.distr_depth:
-                split = _split_node_wrapper(sample, self.n_features, y_s,
-                                            self.n_classes, self.try_features,
-                                            self.random_state,
-                                            samples_file=samples_path,
-                                            features_file=features_path)
+                split = _split_node_wrapper(
+                    sample,
+                    self.n_features,
+                    y_s,
+                    self.try_features,
+                    self.random_state,
+                    samples_file=samples_path,
+                    features_file=features_path,
+                )
                 node_info, left_group, y_l, right_group, y_r = split
                 compss_delete_object(sample)
                 compss_delete_object(y_s)
@@ -124,13 +131,17 @@
                 tree_traversal.append((node.right, right_group, y_r, depth))
                 tree_traversal.append((node.left, left_group, y_l, depth))
             else:
-                subtree = _build_subtree_wrapper(sample, y_s, self.n_features,
-                                                 self.max_depth - depth,
-                                                 self.n_classes,
-                                                 self.try_features,
-                                                 self.sklearn_max,
-                                                 self.random_state,
-                                                 samples_path, features_path)
+                subtree = _build_subtree_wrapper(
+                    sample,
+                    y_s,
+                    self.n_features,
+                    self.max_depth - depth,
+                    self.try_features,
+                    self.sklearn_max,
+                    self.random_state,
+                    samples_path,
+                    features_path,
+                )
                 node.content = len(self.subtrees)
                 self.subtrees.append(subtree)
                 compss_delete_object(sample)
@@ -155,47 +166,23 @@
 
         """
 
-        assert self.tree is not None, 'The decision tree is not fitted.'
+        assert self.tree is not None, "The decision tree is not fitted."
 
         branch_predictions = []
         for i, subtree in enumerate(self.subtrees):
-            pred = _predict_branch(x_row._blocks, self.tree, self.nodes_info,
-                                   i, subtree, self.distr_depth)
+            pred = _predict_branch(
+                x_row._blocks,
+                self.tree,
+                self.nodes_info,
+                i,
+                subtree,
+                self.distr_depth,
+            )
             branch_predictions.append(pred)
         return _merge_branches(None, *branch_predictions)
 
-    def predict_proba(self, x_row):
-        """Predicts class probabilities for a row block using a fitted tree.
-
-        Parameters
-        ----------
-        x_row : ds-array
-            A row block of samples.
-
-        Returns
-        -------
-        predicted_proba : ndarray
-            An array with the predicted probabilities for the given samples.
-            The shape is (len(subset.samples), self.n_classes), with the index
-            of the column being codes of the fitted
-            dislib.classification.rf.data.RfDataset. The returned object can be
-            a pycompss.runtime.Future object.
-
-        """
-
-        assert self.tree is not None, 'The decision tree is not fitted.'
-
-        branch_predictions = []
-        for i, subtree in enumerate(self.subtrees):
-            pred = _predict_branch_proba(x_row._blocks, self.tree,
-                                         self.nodes_info, i, subtree,
-                                         self.distr_depth, self.n_classes)
-            branch_predictions.append(pred)
-        return _merge_branches(self.n_classes, *branch_predictions)
-
 
 class _Node:
-
     def __init__(self):
         self.content = None
         self.left = None
@@ -204,7 +191,7 @@ def __init__(self):
     def predict(self, sample):
         node_content = self.content
         if isinstance(node_content, _LeafInfo):
-            return np.full((len(sample),), node_content.mode)
+            return np.full((len(sample),), node_content.mean)
         if isinstance(node_content, _SkTreeWrapper):
             if len(sample) > 0:
                 return node_content.sk_tree.predict(sample)
@@ -214,29 +201,9 @@ def predict(self, sample):
             pred[left_mask] = self.left.predict(sample[left_mask])
             pred[~left_mask] = self.right.predict(sample[~left_mask])
             return pred
-        assert len(sample) == 0, 'Type not supported'
-        return np.empty((0,), dtype=np.int64)
+        assert len(sample) == 0, "Type not supported"
+        return np.empty((0,), dtype=np.float64)
 
-    def predict_proba(self, sample, n_classes):
-        node_content = self.content
-        if isinstance(node_content, _LeafInfo):
-            single_pred = node_content.frequencies / node_content.size
-            return np.tile(single_pred, (len(sample), 1))
-        if isinstance(node_content, _SkTreeWrapper):
-            if len(sample) > 0:
-                sk_tree_pred = node_content.sk_tree.predict_proba(sample)
-                pred = np.zeros((len(sample), n_classes), dtype=np.float64)
-                pred[:, node_content.sk_tree.classes_] = sk_tree_pred
-                return pred
-        if isinstance(node_content, _InnerNodeInfo):
-            pred = np.empty((len(sample), n_classes), dtype=np.float64)
-            l_msk = sample[:, node_content.index] <= node_content.value
-            pred[l_msk] = self.left.predict_proba(sample[l_msk], n_classes)
-            pred[~l_msk] = self.right.predict_proba(sample[~l_msk], n_classes)
-            return pred
-        assert len(sample) == 0, 'Type not supported'
-        return np.empty((0, n_classes), dtype=np.float64)
-
 
 class _InnerNodeInfo:
     def __init__(self, index=None, value=None):
@@ -245,10 +212,9 @@ def __init__(self, index=None, value=None):
 
 
 class _LeafInfo:
-    def __init__(self, size=None, frequencies=None, mode=None):
+    def __init__(self, size=None, mean=None):
         self.size = size
-        self.frequencies = frequencies
-        self.mode = mode
+        self.mean = mean
 
 
 class _SkTreeWrapper:
@@ -258,7 +224,7 @@ def __init__(self, tree):
 
 
 def _get_sample_attributes(samples_file, indices):
-    samples_mmap = np.load(samples_file, mmap_mode='r', allow_pickle=False)
+    samples_mmap = np.load(samples_file, mmap_mode="r", allow_pickle=False)
     x = samples_mmap[indices]
     return x
 
@@ -268,25 +234,27 @@ def _get_feature_mmap(features_file, i):
 
 
 def _get_features_mmap(features_file):
-    return np.load(features_file, mmap_mode='r', allow_pickle=False)
+    return np.load(features_file, mmap_mode="r", allow_pickle=False)
 
 
 @task(priority=True, returns=2)
-def _sample_selection(n_samples, y_codes, bootstrap, seed):
+def _sample_selection(n_samples, y_targets, bootstrap, seed):
     if bootstrap:
         random_state = RandomState(seed)
-        selection = random_state.choice(n_samples, size=n_samples,
-                                        replace=True)
+        selection = random_state.choice(
+            n_samples, size=n_samples, replace=True
+        )
         selection.sort()
-        return selection, y_codes[selection]
+        return selection, y_targets[selection]
     else:
-        return np.arange(n_samples), y_codes
+        return np.arange(n_samples), y_targets
 
 
 def _feature_selection(untried_indices, m_try, random_state):
     selection_len = min(m_try, len(untried_indices))
-    return random_state.choice(untried_indices, size=selection_len,
-                               replace=False)
+    return random_state.choice(
+        untried_indices, size=selection_len, replace=False
+    )
 
 
 def _get_groups(sample, y_s, features_mmap, index, value):
@@ -303,59 +271,71 @@
     return left, y_l, right, y_r
 
 
-def _compute_leaf_info(y_s, n_classes):
-    frequencies = np.bincount(y_s, minlength=n_classes)
-    mode = np.argmax(frequencies)
-    return _LeafInfo(len(y_s), frequencies, mode)
+def _compute_leaf_info(y_s):
+    return _LeafInfo(len(y_s), np.mean(y_s))
 
 
-def _split_node_wrapper(sample, n_features, y_s, n_classes, m_try,
-                        random_state, samples_file=None, features_file=None):
+def _split_node_wrapper(
+    sample,
+    n_features,
+    y_s,
+    m_try,
+    random_state,
+    samples_file=None,
+    features_file=None,
+):
     seed = random_state.randint(np.iinfo(np.int32).max)
 
     if features_file is not None:
-        return _split_node_using_features(sample, n_features, y_s, n_classes,
-                                          m_try, features_file, seed)
+        return _split_node_using_features(
+            sample, n_features, y_s, m_try, features_file, seed
+        )
     elif samples_file is not None:
-        return _split_node(sample, n_features, y_s, n_classes, m_try,
-                           samples_file, seed)
+        return _split_node(sample, n_features, y_s, m_try, samples_file, seed)
     else:
-        raise ValueError('Invalid combination of arguments. samples_file is '
-                         'None and features_file is None.')
+        raise ValueError(
+            "Invalid combination of arguments. samples_file is "
+            "None and features_file is None."
+        )
 
 
 @task(features_file=FILE_IN, returns=(object, list, list, list, list))
-def _split_node_using_features(sample, n_features, y_s, n_classes, m_try,
-                               features_file, seed):
-    features_mmap = np.load(features_file, mmap_mode='r', allow_pickle=False)
+def _split_node_using_features(
+    sample, n_features, y_s, m_try, features_file, seed
+):
+    features_mmap = np.load(features_file, mmap_mode="r", allow_pickle=False)
     random_state = RandomState(seed)
-    return _compute_split(sample, n_features, y_s, n_classes, m_try,
-                          features_mmap, random_state)
+    return _compute_split(
+        sample, n_features, y_s, m_try, features_mmap, random_state
+    )
 
 
 @task(samples_file=FILE_IN, returns=(object, list, list, list, list))
-def _split_node(sample, n_features, y_s, n_classes, m_try, samples_file, seed):
-    features_mmap = np.load(samples_file, mmap_mode='r', allow_pickle=False).T
+def _split_node(sample, n_features, y_s, m_try, samples_file, seed):
+    features_mmap = np.load(samples_file, mmap_mode="r", allow_pickle=False).T
     random_state = RandomState(seed)
-    return _compute_split(sample, n_features, y_s, n_classes, m_try,
-                          features_mmap, random_state)
+    return _compute_split(
+        sample, n_features, y_s, m_try, features_mmap, random_state
+    )
 
 
-def _compute_split(sample, n_features, y_s, n_classes, m_try, features_mmap,
-                   random_state):
+def _compute_split(
+    sample, n_features, y_s, m_try, features_mmap, random_state
+):
     node_info = left_group = y_l = right_group = y_r = None
     split_ended = False
     tried_indices = []
     while not split_ended:
        untried_indices = np.setdiff1d(np.arange(n_features), tried_indices)
-        index_selection = _feature_selection(untried_indices, m_try,
-                                             random_state)
+        index_selection = _feature_selection(
+            untried_indices, m_try, random_state
+        )
         b_score = float_info.max
         b_index = None
         b_value = None
         for index in index_selection:
             feature = features_mmap[index]
-            score, value = test_split(sample, y_s, feature, n_classes)
+            score, value = test_split(sample, y_s, feature)
             if score < b_score:
                 b_score, b_value, b_index = score, value, index
         groups = _get_groups(sample, y_s, features_mmap, b_index, b_value)
@@ -367,7 +347,7 @@ def _compute_split(sample, n_features, y_s, n_classes, m_try, features_mmap,
         tried_indices.extend(list(index_selection))
         if len(tried_indices) == n_features:
             split_ended = True
-            node_info = _compute_leaf_info(y_s, n_classes)
+            node_info = _compute_leaf_info(y_s)
             left_group = sample
             y_l = y_s
             right_group = np.array([], dtype=np.int64)
@@ -376,48 +356,111 @@ def _compute_split(sample, n_features, y_s, n_classes, m_try, features_mmap,
     return node_info, left_group, y_l, right_group, y_r
 
 
-def _build_subtree_wrapper(sample, y_s, n_features, max_depth, n_classes,
-                           m_try, sklearn_max, random_state, samples_file,
-                           features_file):
+def _build_subtree_wrapper(
+    sample,
+    y_s,
+    n_features,
+    max_depth,
+    m_try,
+    sklearn_max,
+    random_state,
+    samples_file,
+    features_file,
+):
     seed = random_state.randint(np.iinfo(np.int32).max)
     if features_file is not None:
-        return _build_subtree_using_features(sample, y_s, n_features,
-                                             max_depth, n_classes, m_try,
-                                             sklearn_max, seed, samples_file,
-                                             features_file)
+        return _build_subtree_using_features(
+            sample,
+            y_s,
+            n_features,
+            max_depth,
+            m_try,
+            sklearn_max,
+            seed,
+            samples_file,
+            features_file,
+        )
     else:
-        return _build_subtree(sample, y_s, n_features, max_depth, n_classes,
-                              m_try, sklearn_max, seed, samples_file)
+        return _build_subtree(
+            sample,
+            y_s,
+            n_features,
+            max_depth,
+            m_try,
+            sklearn_max,
+            seed,
+            samples_file,
+        )
 
 
 @task(samples_file=FILE_IN, features_file=FILE_IN, returns=_Node)
-def _build_subtree_using_features(sample, y_s, n_features, max_depth,
-                                  n_classes, m_try, sklearn_max, seed,
-                                  samples_file, features_file):
+def _build_subtree_using_features(
+    sample,
+    y_s,
+    n_features,
+    max_depth,
+    m_try,
+    sklearn_max,
+    seed,
+    samples_file,
+    features_file,
+):
     random_state = RandomState(seed)
-    return _compute_build_subtree(sample, y_s, n_features, max_depth,
-                                  n_classes, m_try, sklearn_max, random_state,
-                                  samples_file, features_file=features_file)
+    return _compute_build_subtree(
+        sample,
+        y_s,
+        n_features,
+        max_depth,
+        m_try,
+        sklearn_max,
+        random_state,
+        samples_file,
+        features_file=features_file,
+    )
 
 
 @task(samples_file=FILE_IN, returns=_Node)
-def _build_subtree(sample, y_s, n_features, max_depth, n_classes, m_try,
-                   sklearn_max, seed, samples_file):
+def _build_subtree(
+    sample,
+    y_s,
+    n_features,
+    max_depth,
+    m_try,
+    sklearn_max,
+    seed,
+    samples_file,
+):
     random_state = RandomState(seed)
-    return _compute_build_subtree(sample, y_s, n_features, max_depth,
-                                  n_classes, m_try, sklearn_max, random_state,
-                                  samples_file)
-
-
-def _compute_build_subtree(sample, y_s, n_features, max_depth, n_classes,
-                           m_try, sklearn_max, random_state, samples_file,
-                           features_file=None, use_sklearn=True):
+    return _compute_build_subtree(
+        sample,
+        y_s,
+        n_features,
+        max_depth,
+        m_try,
+        sklearn_max,
+        random_state,
+        samples_file,
+    )
+
+
+def _compute_build_subtree(
+    sample,
+    y_s,
+    n_features,
+    max_depth,
+    m_try,
+    sklearn_max,
+    random_state,
+    samples_file,
+    features_file=None,
+    use_sklearn=True,
+):
     if not sample.size:
         return _Node()
     if features_file is not None:
-        mmap = np.load(features_file, mmap_mode='r', allow_pickle=False)
+        mmap = np.load(features_file, mmap_mode="r", allow_pickle=False)
     else:
-        mmap = np.load(samples_file, mmap_mode='r', allow_pickle=False).T
+        mmap = np.load(samples_file, mmap_mode="r", allow_pickle=False).T
     subtree = _Node()
     tree_traversal = [(subtree, sample, y_s, 0)]
     while tree_traversal:
@@ -428,30 +471,41 @@ def _compute_build_subtree(sample, y_s, n_features, max_depth, n_classes,
                     sklearn_max_depth = None
                 else:
                     sklearn_max_depth = max_depth - depth
-                dt = SklearnDTClassifier(max_features=m_try,
-                                         max_depth=sklearn_max_depth,
-                                         random_state=random_state)
-                unique = np.unique(sample, return_index=True,
-                                   return_counts=True)
+                dt = SklearnDTRegressor(
+                    max_features=m_try,
+                    max_depth=sklearn_max_depth,
+                    random_state=random_state,
+                )
+                unique = np.unique(
+                    sample, return_index=True, return_counts=True
+                )
                 sample, new_indices, sample_weight = unique
                 x = _get_sample_attributes(samples_file, sample)
                 y_s = y_s[new_indices]
                 dt.fit(x, y_s, sample_weight=sample_weight, check_input=False)
                 node.content = _SkTreeWrapper(dt)
             else:
-                split = _compute_split(sample, n_features, y_s, n_classes,
-                                       m_try, mmap, random_state)
+                split = _compute_split(
+                    sample,
+                    n_features,
+                    y_s,
+                    m_try,
+                    mmap,
+                    random_state,
+                )
                 node_info, left_group, y_l, right_group, y_r = split
                 node.content = node_info
                 if isinstance(node_info, _InnerNodeInfo):
                     node.left = _Node()
                     node.right = _Node()
-                    tree_traversal.append((node.right, right_group, y_r,
-                                           depth + 1))
-                    tree_traversal.append((node.left, left_group, y_l,
-                                           depth + 1))
+                    tree_traversal.append(
+                        (node.right, right_group, y_r, depth + 1)
+                    )
+                    tree_traversal.append(
+                        (node.left, left_group, y_l, depth + 1)
+                    )
         else:
-            node.content = _compute_leaf_info(y_s, n_classes)
+            node.content = _compute_leaf_info(y_s)
     return subtree
 
 
@@ -462,7 +516,7 @@ def _merge(*object_list):
 
 def _get_subtree_path(subtree_index, distr_depth):
     if distr_depth == 0:
-        return ''
+        return ""
     return bin(subtree_index)[2:].zfill(distr_depth)
 
 
@@ -471,12 +525,12 @@ def _get_predicted_indices(samples, tree, nodes_info, path):
     for direction in path:
         node_info = nodes_info[tree.content]
         if isinstance(node_info, _LeafInfo):
-            if direction == '1':
+            if direction == "1":
                 idx_mask[:] = 0
         else:
             col = node_info.index
             value = node_info.value
-            if direction == '0':
+            if direction == "0":
                 idx_mask[idx_mask] = samples[idx_mask, col] <= value
                 tree = tree.left
             else:
@@ -486,8 +540,9 @@
 
 
 @task(row_blocks={Type: COLLECTION_IN, Depth: 2}, returns=1)
-def _predict_branch(row_blocks, tree, nodes_info, subtree_index, subtree,
-                    distr_depth):
+def _predict_branch(
+    row_blocks, tree, nodes_info, subtree_index, subtree, distr_depth
+):
     samples = Array._merge_blocks(row_blocks)
     path = _get_subtree_path(subtree_index, distr_depth)
     indices_mask = _get_predicted_indices(samples, tree, nodes_info, path)
@@ -495,16 +550,6 @@
     return indices_mask, prediction
 
 
-@task(row_blocks={Type: COLLECTION_IN, Depth: 2}, returns=1)
-def _predict_branch_proba(row_blocks, tree, nodes_info, subtree_index, subtree,
-                          distr_depth, n_classes):
-    samples = Array._merge_blocks(row_blocks)
-    path = _get_subtree_path(subtree_index, distr_depth)
-    indices_mask = _get_predicted_indices(samples, tree, nodes_info, path)
-    prediction = subtree.predict_proba(samples[indices_mask], n_classes)
-    return indices_mask, prediction
-
-
 @task(returns=list)
 def _merge_branches(n_classes, *predictions):
     samples_len = len(predictions[0][0])
diff --git a/dislib/regression/rf/test_split.py b/dislib/regression/rf/test_split.py
index 70922783..aa482b3c 100644
--- a/dislib/regression/rf/test_split.py
+++ b/dislib/regression/rf/test_split.py
@@ -3,15 +3,15 @@
 import numpy as np
 
 
-def gini_criteria_proxy(l_weight, l_length, r_weight, r_length,
-                        not_repeated):
+def mse_criteria_proxy(l_weight, l_length, r_weight, r_length, not_repeated):
     """
-    Maximizing the Gini gain is equivalent to minimizing this proxy function.
+    Maximizing the MSE gain is equivalent to minimizing this proxy function.
     """
     return -(l_weight / l_length + r_weight / r_length) * not_repeated
 
 
-def test_split(sample, y_s, feature, n_classes):
+def test_split(sample, y_s, feature):
     size = y_s.shape[0]
     if size == 0:
         return float_info.max, np.float64(np.inf)
@@ -21,28 +21,26 @@
     y_sorted = y_s[sort_indices]
     f_sorted = f[sort_indices]
 
+    # The threshold must not equal the value of any sample
     not_repeated = np.empty(size, dtype=np.bool_)
-    not_repeated[0: size - 1] = (f_sorted[1:] != f_sorted[:-1])
+    not_repeated[0 : size - 1] = f_sorted[1:] != f_sorted[:-1]
     not_repeated[size - 1] = True
 
-    l_freq = np.zeros((n_classes, size), dtype=np.int64)
-    l_freq[y_sorted, np.arange(size)] = 1
-
-    r_freq = np.zeros((n_classes, size), dtype=np.int64)
-    r_freq[:, 1:] = l_freq[:, :0:-1]
-
-    l_weight = np.sum(np.square(np.cumsum(l_freq, axis=-1)), axis=0)
-    r_weight = np.sum(np.square(np.cumsum(r_freq, axis=-1)), axis=0)[::-1]
+    # Squared sum of the y values of each branch
+    r_weight = np.zeros(size)
+    l_weight = np.square(np.cumsum(y_sorted, axis=-1))
+    r_weight[:-1] = np.square(np.cumsum(y_sorted[::-1], axis=-1)[-2::-1])
 
+    # Number of samples of each branch
     l_length = np.arange(1, size + 1, dtype=np.int32)
     r_length = np.arange(size - 1, -1, -1, dtype=np.int32)
-    r_length[size - 1] = 1  # Avoid div by zero, the right score is 0 anyways
+    r_length[size - 1] = 1  # Avoid div by zero; the right score is 0 anyway
 
-    scores = gini_criteria_proxy(l_weight, l_length, r_weight, r_length,
-                                 not_repeated)
+    scores = mse_criteria_proxy(
+        l_weight, l_length, r_weight, r_length, not_repeated
+    )
 
     min_index = size - np.argmin(scores[::-1]) - 1
-
     if min_index + 1 == size:
         b_value = np.float64(np.inf)
     else:
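
Note on the proxy criterion: for a candidate split, the total within-branch
squared error is sum(y**2) - (S_l**2 / n_l + S_r**2 / n_r), where S_b and n_b
are the sum and count of the y values in each branch. Since sum(y**2) is the
same for every split position, minimizing the error is equivalent to
minimizing -(S_l**2 / n_l + S_r**2 / n_r), which is exactly what
mse_criteria_proxy returns with l_weight = S_l**2 and r_weight = S_r**2. A
minimal standalone NumPy sketch that checks this equivalence (toy data; the
helper names below are illustrative, not from dislib):

    import numpy as np

    def brute_force_sse(y):
        # Total within-branch squared error when the left branch takes
        # y[: i + 1], for each split position i.
        sse = np.empty(len(y) - 1)
        for i in range(len(y) - 1):
            left, right = y[: i + 1], y[i + 1 :]
            sse[i] = ((left - left.mean()) ** 2).sum()
            sse[i] += ((right - right.mean()) ** 2).sum()
        return sse

    def proxy(y):
        # -(S_l**2 / n_l + S_r**2 / n_r): differs from brute_force_sse
        # only by the constant (y**2).sum(), so the argmin is the same.
        n = len(y)
        l_sum = np.cumsum(y)[:-1]
        r_sum = y.sum() - l_sum
        n_l = np.arange(1, n)
        n_r = n - n_l
        return -(l_sum**2 / n_l + r_sum**2 / n_r)

    rng = np.random.default_rng(0)
    y = rng.normal(size=50)  # order stands in for sorting by a feature
    assert np.argmin(brute_force_sse(y)) == np.argmin(proxy(y))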
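
On the fit() side, the patch relies only on a narrow dataset interface:
get_n_features(), get_n_samples(), get_y_targets(), and the samples_path /
features_path attributes, which point at .npy files that the tasks open with
np.load(..., mmap_mode="r"). A hedged sketch of that contract follows; the
class is a hypothetical stand-in, not dislib's RfDataset, and actually
running fit() requires the COMPSs runtime because the helpers are
@task-decorated:

    import numpy as np
    from numpy.random import RandomState

    class ToyRfDataset:
        # Minimal stand-in exposing only what fit() touches in this patch.
        def __init__(self, x, y, samples_path="samples.npy"):
            np.save(samples_path, x.astype(np.float32))  # row-major samples
            self.samples_path = samples_path
            self.features_path = None  # no column-major copy available
            self._n_samples, self._n_features = x.shape
            self._y = y.astype(np.float64)

        def get_n_features(self):
            return self._n_features

        def get_n_samples(self):
            return self._n_samples

        def get_y_targets(self):
            return self._y

    # Under a running COMPSs runtime:
    # tree = DecisionTreeRegressor(
    #     try_features=2, max_depth=np.inf, distr_depth=1,
    #     sklearn_max=1e8, bootstrap=True, random_state=RandomState(0),
    # )
    # tree.fit(ToyRfDataset(x, y))

Keeping the samples in an .npy file means every task can memory-map it with
mmap_mode="r" instead of deserializing the full array per task.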