-
Notifications
You must be signed in to change notification settings - Fork 25
/
SRForest.py
29 lines (21 loc) · 967 Bytes
/
SRForest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from sklearn.datasets import load_diabetes
from sklearn.model_selection import cross_val_score
from evolutionary_forest.forest import EvolutionaryForestRegressor
# Tuning grid holding one empty parameter set.
# NOTE(review): presumably read by an external benchmark harness for
# hyper-parameter search. An empty dict carries no overrides for the
# constructor arguments below; confirm against that harness.
hyper_params = [
    {},
]

# Regressor instance exposed at module level. In this file it is
# cross-validated and fitted in the ``__main__`` smoke test.
# Argument meanings are defined by EvolutionaryForestRegressor itself.
est = EvolutionaryForestRegressor(
    max_height=8,
    normalize=True,
    select='AutomaticLexicase',
    boost_size=100,
    basic_primitives='sin-cos',
    mutation_scheme='EDA-Terminal-PM',
    semantic_diversity='GreedySelection-Resampling',
    initial_tree_size='2-6',
    cross_pb=0.9,
    mutation_pb=0.1,
    gene_num=20,
    n_gen=100,
    n_pop=200,
    base_learner='Fast-RidgeDT',
)
def complexity(est: EvolutionaryForestRegressor):
    """Report the complexity of the model held by ``est``.

    This is a thin adapter: it calls the estimator's own ``complexity()``
    method and returns whatever that method produces, unchanged.

    NOTE(review): presumably expects an already-fitted estimator. The
    ``__main__`` smoke test in this file calls it only after ``est.fit``;
    confirm the requirement in EvolutionaryForestRegressor.
    """
    measured = est.complexity()
    return measured
# NOTE(review): module-level ``model`` is deliberately left as None.
# Presumably an external benchmark harness looks this attribute up;
# confirm before removing or renaming it.
model = None

if __name__ == '__main__':
    # Smoke test on scikit-learn's diabetes dataset. It prints the
    # cross-validation scores, then fits once and prints the result
    # of complexity().
    features, target = load_diabetes(return_X_y=True)
    cv_scores = cross_val_score(est, features, target, n_jobs=-1)
    print(cv_scores)
    est.fit(features, target)
    print(complexity(est))