Skip to content

Commit 14e26b4

Browse files
Refactor tree_depths function to use model.get_depth() in CW2 (2).py
1 parent ed6eee5 commit 14e26b4

File tree

1 file changed

+19
-5
lines changed

1 file changed

+19
-5
lines changed

CW2 (2).py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -151,14 +151,30 @@ def optimal_ccp_alpha(x_train, y_train, x_test, y_test):
151151
def tree_depths(model):
152152
depth=None
153153
# Get the depth of the unpruned tree
154-
# Insert your code here for task 9
154+
depth = model.get_depth()
155155
return depth
156156

157157
# Task 10 [10 marks]: Feature importance
158158
def important_feature(x_train, y_train,header_list):
159159
best_feature=None
160160
# Train decision tree model and increase Cost Complexity Parameter until the depth reaches 1
161-
# Insert your code here for task 10
161+
tree = DecisionTreeClassifier(random_state=0)
162+
tree.fit(x_train, y_train)
163+
# Calculate the cost complexity pruning path
164+
path = tree.cost_complexity_pruning_path(x_train, y_train)
165+
cpp_alphas = path.ccp_alphas
166+
167+
# Iterate over the ccp_alphas to find the best feature
168+
for ccp_alpha in cpp_alphas:
169+
# Train the decision tree with the current ccp_alpha
170+
tree = DecisionTreeClassifier(random_state=0, ccp_alpha=ccp_alpha)
171+
tree.fit(x_train, y_train)
172+
# Check if the tree length is 1
173+
if tree.get_depth() == 1:
174+
break
175+
176+
# Get the feature importance
177+
best_feature = header_list[np.argmax(tree.feature_importances_)]
162178
return best_feature
163179

164180

@@ -249,9 +265,7 @@ def important_feature(x_train, y_train,header_list):
249265
- Line 94 is inspired from https://www.geeksforgeeks.org/learning-model-building-scikit-learn-python-machine-learning-library/
250266
- Line 100 is inspired from https://www.askpython.com/python/examples/python-predict-function
251267
- Line 106-109 is inspired from https://www.linkedin.com/pulse/basics-decision-tree-python-omkar-sutar#:~:text=To%20calculate%20the%20accuracy%20score,from%20the%20scikit%2Dlearn%20library.&text=In%20this%20code%2C%20y_test%20is,by%20the%20decision%20tree%20model.
252-
253-
254-
'''
268+
- Line 154 is inspired from https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html#sklearn.tree.DecisionTreeClassifier.get_depth'''
255269

256270

257271

0 commit comments

Comments
 (0)