Skip to content

Commit

Permalink
init CART notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
vmirly committed Nov 24, 2024
1 parent 8ef024c commit 9a4113d
Showing 1 changed file with 163 additions and 0 deletions.
163 changes: 163 additions & 0 deletions decision_tree/cart.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## CART: Classification and Regression Tree\n",
"\n",
"### Gini impurity\n",
"\n",
" * with $c$ classes, where $p_j$ is the fraction of samples belonging to class $j$:\n",
"$$I_G(p)=1 - \\sum_{j=1}^c p_j^2$$"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"from collections import Counter\n",
"\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.48979591836734704\n"
]
}
],
"source": [
"\n",
"def calc_gini(y):\n",
"    \"\"\"Return the Gini impurity, 1 - sum_j p_j**2, of the labels in y.\n",
"\n",
"    p_j is the share of entries of y equal to the j-th distinct label.\n",
"    An empty y is reported as 0.\n",
"    \"\"\"\n",
"    n_samples = len(y)\n",
"    if n_samples == 0:\n",
"        return 0\n",
"    squared_total = 0\n",
"    for count in Counter(y).values():\n",
"        squared_total += (count / n_samples) ** 2\n",
"    return 1 - squared_total\n",
"\n",
"# sanity check: four 1s and three 0s -> 1 - (4/7)**2 - (3/7)**2\n",
"y = [1, 1, 1, 1, 0, 0, 0]\n",
"print(calc_gini(y))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9852281360342516\n"
]
}
],
"source": [
"def entropy(y):\n",
"    \"\"\"Return the Shannon entropy of the labels in y, using log base 2.\n",
"\n",
"    Each distinct label contributes -p * log2(p), where p is its share of y.\n",
"    \"\"\"\n",
"    n_samples = len(y)\n",
"    total = 0\n",
"    for count in Counter(y).values():\n",
"        p = count / n_samples\n",
"        total += p * math.log(p, 2)\n",
"    return -total\n",
"\n",
"# sanity check: four 1s and three 0s, slightly below 1 bit\n",
"y = [1, 1, 1, 1, 0, 0, 0]\n",
"print(entropy(y))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0 3] [1 2]\n"
]
}
],
"source": [
"def split_dataset(X, feature_index, threshold):\n",
"    \"\"\"Split the row indices of X on one feature column.\n",
"\n",
"    Returns a pair (left, right) of index arrays: rows whose value in column\n",
"    feature_index is < threshold, and rows whose value is >= threshold.\n",
"    \"\"\"\n",
"    column = X[:, feature_index]\n",
"    goes_left = column < threshold\n",
"    goes_right = column >= threshold\n",
"    return np.nonzero(goes_left)[0], np.nonzero(goes_right)[0]\n",
"\n",
"# sanity check: rows 0 and 3 have a negative first feature\n",
"X = np.array([[-1, 2], [3, 4], [3, 6], [-2, 8]])\n",
"feature_index = 0\n",
"threshold = 0\n",
"left, right = split_dataset(X, feature_index, threshold)\n",
"print(left, right)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.5\n",
"0.0\n"
]
}
],
"source": [
"def gini_split(X, y, feature_index, threshold):\n",
"    \"\"\"Weighted Gini impurity of the two children of a threshold split.\n",
"\n",
"    Rows of X whose value in column `feature_index` is below `threshold`\n",
"    form the left child; all other rows form the right child. Each child's\n",
"    impurity is weighted by its share of the samples:\n",
"\n",
"        gini = (n_left / n) * gini(left) + (n_right / n) * gini(right)\n",
"\n",
"    Args:\n",
"        X: 2-D NumPy array of features, one row per sample.\n",
"        y: NumPy array of labels aligned with the rows of X.\n",
"        feature_index: column of X to split on.\n",
"        threshold: split value.\n",
"\n",
"    Returns:\n",
"        The weighted impurity as a float; 0.0 when y is empty.\n",
"    \"\"\"\n",
"    m = len(y)\n",
"    if m == 0:\n",
"        # nothing to split; also avoids dividing by zero below\n",
"        return 0.0\n",
"    left_indices, right_indices = split_dataset(X, feature_index, threshold)\n",
"    left_y, right_y = y[left_indices], y[right_indices]\n",
"    # the weights are already fractions of m, so they are not divided by m again\n",
"    w_left, w_right = len(left_y) / m, len(right_y) / m\n",
"    return w_left * calc_gini(left_y) + w_right * calc_gini(right_y)\n",
"\n",
"# testing the function (X, feature_index and threshold come from the split_dataset cell)\n",
"# both children hold one label of each class -> 0.5\n",
"y = np.array([1, 1, 0, 0])\n",
"print(gini_split(X, y, feature_index, threshold))\n",
"# both children are pure -> 0.0\n",
"y = np.array([0, 1, 1, 0])\n",
"print(gini_split(X, y, feature_index, threshold))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "py310",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 9a4113d

Please sign in to comment.