my_DT.py
import pandas as pd
import numpy as np
from collections import Counter
class my_DT:
    def __init__(self, criterion="gini", max_depth=8, min_impurity_decrease=0, min_samples_split=2):
        # criterion: {"gini", "entropy"}
        # max_depth: stop splitting a node once its depth equals max_depth.
        #   Depth of a binary tree: the maximum number of edges from the root node to a leaf node.
        # min_impurity_decrease: only split a node if the impurity decrease after the split is >= this value.
        #   Weighted impurity decrease: impurity - (N_t_R / N * right_impurity + N_t_L / N * left_impurity)
        # min_samples_split: only split a node that contains >= min_samples_split samples.
        self.criterion = criterion
        self.max_depth = int(max_depth)
        self.min_impurity_decrease = min_impurity_decrease
        self.min_samples_split = int(min_samples_split)
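
    def _impurity(self, counts):
        # Helper sketch (an addition, not part of the original template): node impurity
        # from a Counter of class labels, for the two criteria accepted by __init__.
        # Gini: 1 - sum(p_i^2); entropy: -sum(p_i * log2(p_i)).
        N = sum(counts.values())
        probs = np.array([c / N for c in counts.values()])
        if self.criterion == "entropy":
            return float(-np.sum(probs * np.log2(probs)))
        return float(1.0 - np.sum(probs ** 2))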
    def fit(self, X, y):
        # X: pd.DataFrame, independent variables, float
        # y: list, np.array or pd.Series, dependent variable, int or str
        self.classes_ = list(set(list(y)))
        # write your code below
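        # A minimal sketch, assuming a recursive builder that stores each node as a dict
        # holding the node's class counts and, for internal nodes, the chosen feature/threshold.
        # The node layout and the self.tree_ attribute are illustrative choices, not
        # something the template prescribes.
        y = np.asarray(list(y))
        features = list(X.columns)

        def build(rows, depth):
            counts = Counter(y[rows])
            node = {"counts": counts, "feature": None, "threshold": None, "left": None, "right": None}
            # Stop at max depth, at pure nodes, or at nodes too small to split.
            if depth >= self.max_depth or len(rows) < self.min_samples_split or len(counts) == 1:
                return node
            impurity = self._impurity(counts)
            best = None  # (impurity decrease, feature, threshold, left rows, right rows)
            for f in features:
                values = X[f].to_numpy()[rows]
                for t in np.unique(values)[:-1]:
                    left, right = rows[values <= t], rows[values > t]
                    decrease = impurity - (len(right) / len(rows) * self._impurity(Counter(y[right]))
                                           + len(left) / len(rows) * self._impurity(Counter(y[left])))
                    if best is None or decrease > best[0]:
                        best = (decrease, f, t, left, right)
            if best is None or best[0] < self.min_impurity_decrease:
                return node
            node["feature"], node["threshold"] = best[1], best[2]
            node["left"] = build(best[3], depth + 1)
            node["right"] = build(best[4], depth + 1)
            return node

        self.tree_ = build(np.arange(len(X)), 0)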
        return
    def predict(self, X):
        # X: pd.DataFrame, independent variables, float
        # return predictions: list
        # write your code below
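        # A minimal sketch, assuming fit() stored the tree sketched above in self.tree_:
        # route each test point to a leaf and predict that leaf's majority class.
        def majority_class(node, row):
            while node["feature"] is not None:
                node = node["left"] if row[node["feature"]] <= node["threshold"] else node["right"]
            return node["counts"].most_common(1)[0][0]

        predictions = [majority_class(self.tree_, row) for _, row in X.iterrows()]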
        return predictions
    def predict_proba(self, X):
        # X: pd.DataFrame, independent variables, float
        # Example:
        #   self.classes_ = ["2", "1"]
        #   the leaf reached by a test data point has class counts {"1": 2, "2": 1}
        #   then the probabilities for that data point are {"2": 1/3, "1": 2/3}
        # return probs = pd.DataFrame(list of prob, columns = self.classes_)
        # write your code below
        ##################
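        # A minimal sketch, assuming the same self.tree_ layout as in fit():
        # each class's probability is its share of the training samples in the reached leaf.
        def leaf_counts(node, row):
            while node["feature"] is not None:
                node = node["left"] if row[node["feature"]] <= node["threshold"] else node["right"]
            return node["counts"]

        prob_rows = []
        for _, row in X.iterrows():
            counts = leaf_counts(self.tree_, row)
            total = sum(counts.values())
            prob_rows.append({c: counts.get(c, 0) / total for c in self.classes_})
        probs = pd.DataFrame(prob_rows, columns=self.classes_)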
        return probs
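

# A small usage sketch (illustrative, not part of the template): fit on a tiny
# DataFrame and inspect the predicted labels and class probabilities.
if __name__ == "__main__":
    X = pd.DataFrame({"x1": [1.0, 2.0, 3.0, 4.0], "x2": [0.5, 0.4, 0.9, 1.1]})
    y = ["a", "a", "b", "b"]
    clf = my_DT(criterion="gini", max_depth=3)
    clf.fit(X, y)
    print(clf.predict(X))
    print(clf.predict_proba(X))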