Skip to content

Commit

Permalink
defined the class
Browse files Browse the repository at this point in the history
  • Loading branch information
vmirly committed Nov 15, 2024
1 parent eef8c68 commit d345659
Showing 1 changed file with 51 additions and 1 deletion.
52 changes: 51 additions & 1 deletion decision_tree/decision_tree.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,57 @@
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"class DecisionTree():\n",
" def __init__(\n",
" self,\n",
" max_depth=None,\n",
" min_samples_split=2,\n",
" criterion='gini'\n",
" ):\n",
" self.max_depth = max_depth\n",
" self.min_samples_split = min_samples_split\n",
" self.criterion = criterion\n",
"\n",
" def fit(self, X, y):\n",
" self.n_classes = len(np.unique(y))\n",
" self.n_features = X.shape[1]\n",
" self.tree = self._build_tree(X, y)\n",
"\n",
" def _build_tree(self, X, y, depth=0):\n",
" n_samples, n_features = X.shape\n",
"\n",
" if depth >= self.max_depth or n_samples < self.min_samples_split:\n",
" return {\n",
" \"leaf\": True,\n",
" 'value': self._most_common_label(y)\n",
" }\n",
" \n",
" if self.criterion == 'gini':\n",
" impurity_func = calc_gini\n",
" else:\n",
" impurity_func = entropy\n",
"\n",
" best_impurity = float('inf')\n",
"\n",
" best_feature, best_threshold = self._find_best_split(X, y, n_features)\n",
" if best_feature is None:\n",
" return {\n",
" \"leaf\": True,\n",
" 'value': self._most_common_label(y)\n",
" }\n",
" left_indices, right_indices = split_dataset(X, best_feature, best_threshold)\n",
"\n",
" return {\n",
" \"leaf\": False,\n",
" \"feature_index\": best_feature,\n",
" \"threshold\": best_threshold,\n",
" \"left\": self._build_tree(X[left_indices], y[left_indices], depth + 1),\n",
" \"right\": self._build_tree(X[right_indices], y[right_indices], depth + 1)\n",
" }\n",
" \n",
" \n"
]
}
],
"metadata": {
Expand Down

0 comments on commit d345659

Please sign in to comment.