From 9a4113d6ee223d26b0a8d9e04e4dc387af536556 Mon Sep 17 00:00:00 2001 From: vmirly Date: Sun, 24 Nov 2024 06:47:29 -0800 Subject: [PATCH] init CART notebook --- decision_tree/cart.ipynb | 163 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 163 insertions(+) create mode 100644 decision_tree/cart.ipynb diff --git a/decision_tree/cart.ipynb b/decision_tree/cart.ipynb new file mode 100644 index 0000000..5295c71 --- /dev/null +++ b/decision_tree/cart.ipynb @@ -0,0 +1,163 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## CART: Classification and Regresstion Tree\n", + "\n", + "### Gini impurity\n", + "\n", + " * with $c$ classes:\n", + "$$I_G(p)=1 - \\sum_{j=1}^c p_j^2$$" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "from collections import Counter\n", + "\n", + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.48979591836734704\n" + ] + } + ], + "source": [ + "\n", + "def calc_gini(y):\n", + " m = len(y)\n", + " if m == 0:\n", + " return 0\n", + " counts = Counter(y)\n", + " probas = [c/m for c in counts.values()]\n", + " impurity = 1 - sum([p**2 for p in probas])\n", + " return impurity\n", + "\n", + "# testing the function\n", + "y = [1, 1, 1, 1, 0, 0, 0]\n", + "print(calc_gini(y))" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.9852281360342516\n" + ] + } + ], + "source": [ + "def entropy(y):\n", + " m = len(y)\n", + " counts = Counter(y)\n", + " probas = [c/m for c in counts.values()]\n", + " return -sum([p * math.log(p, 2) for p in probas])\n", + "\n", + "# testing the function\n", + "y = [1, 1, 1, 1, 0, 0, 0]\n", + "print(entropy(y))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[0 3] [1 2]\n" + ] + } + ], + "source": [ + "def split_dataset(X, feature_index, threshold):\n", + " left_indices = np.where(X[:, feature_index] < threshold)[0]\n", + " right_indices = np.where(X[:, feature_index] >= threshold)[0]\n", + "\n", + " return left_indices, right_indices\n", + "\n", + "# testing the function\n", + "X = np.array([[-1, 2], [3, 4], [3, 6], [-2, 8]])\n", + "feature_index = 0\n", + "threshold = 0\n", + "left, right = split_dataset(X, feature_index, threshold)\n", + "print(left, right)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.125\n", + "0.0\n" + ] + } + ], + "source": [ + "def gini_split(X, y, feature_index, threshold):\n", + " left_indices, right_indices = split_dataset(X, feature_index, threshold)\n", + " left_y, right_y = y[left_indices], y[right_indices]\n", + " m = len(y)\n", + " w_left, w_right = len(left_y) / m, len(right_y) / m\n", + " gini_left = calc_gini(left_y)\n", + " gini_right = calc_gini(right_y)\n", + " gini = (w_left/m) * gini_left + (w_right/m) * gini_right\n", + " return gini\n", + "\n", + "# testing the function\n", + "y = np.array([1, 1, 0, 0])\n", + "print(gini_split(X, y, feature_index, threshold))\n", + "y = np.array([0, 1, 1, 0])\n", + "print(gini_split(X, y, feature_index, threshold))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "py310", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.15" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}