Skip to content

Commit

Permalink
functions to split and calc gini_split
Browse files Browse the repository at this point in the history
  • Loading branch information
vmirly committed Nov 15, 2024
1 parent 67d1ce8 commit eef8c68
Showing 1 changed file with 66 additions and 4 deletions.
70 changes: 66 additions & 4 deletions decision_tree/decision_tree.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,19 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"from collections import Counter"
"from collections import Counter\n",
"\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -53,7 +55,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 3,
"metadata": {},
"outputs": [
{
Expand All @@ -76,6 +78,66 @@
"print(entropy(y))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0 3] [1 2]\n"
]
}
],
"source": [
"def split_dataset(X, feature_index, threshold):\n",
" left_indices = np.where(X[:, feature_index] < threshold)[0]\n",
" right_indices = np.where(X[:, feature_index] >= threshold)[0]\n",
"\n",
" return left_indices, right_indices\n",
"\n",
"# testing the function\n",
"X = np.array([[-1, 2], [3, 4], [3, 6], [-2, 8]])\n",
"feature_index = 0\n",
"threshold = 0\n",
"left, right = split_dataset(X, feature_index, threshold)\n",
"print(left, right)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.125\n",
"0.0\n"
]
}
],
"source": [
"def gini_split(X, y, feature_index, threshold):\n",
" left_indices, right_indices = split_dataset(X, feature_index, threshold)\n",
" left_y, right_y = y[left_indices], y[right_indices]\n",
" m = len(y)\n",
" w_left, w_right = len(left_y) / m, len(right_y) / m\n",
" gini_left = calc_gini(left_y)\n",
" gini_right = calc_gini(right_y)\n",
" gini = (w_left/m) * gini_left + (w_right/m) * gini_right\n",
" return gini\n",
"\n",
"# testing the function\n",
"y = np.array([1, 1, 0, 0])\n",
"print(gini_split(X, y, feature_index, threshold))\n",
"y = np.array([0, 1, 1, 0])\n",
"print(gini_split(X, y, feature_index, threshold))"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down

0 comments on commit eef8c68

Please sign in to comment.