diff --git a/LICENSE b/LICENSE index c6e9aa05..329db5e3 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2022 Andrej +Copyright (c) 2022 Andrej Karpathy Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index d5b84af7..90a30bca 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,35 @@ -# nn-zero-to-hero -Neural Networks: Zero to Hero + +## Neural Networks: Zero to Hero + +A course on neural networks that starts all the way at the basics. The course is a series of YouTube videos where we code and train neural networks together. The Jupyter notebooks we build in the videos are then captured here inside the [lectures](lectures/) directory. Every lecture also has a set of exercises included in the video description. (This may grow into something more respectable). + +--- + +**Lecture 1: The spelled-out intro to neural networks and backpropagation: building micrograd** + +Backpropagation and training of neural networks. Assumes basic knowledge of Python and a vague recollection of calculus from high school. + +- [YouTube video lecture](https://www.youtube.com/watch?v=VMj-3S1tku0) +- [Jupyter notebook files](lectures/micrograd) +- [micrograd Github repo](https://github.com/karpathy/micrograd) + +--- + +**Lecture 2: The spelled-out intro to language modeling: building makemore** + +We implement a bigram character-level language model, which we will further complexify in followup videos into a modern Transformer language model, like GPT. In this video, the focus is on (1) introducing torch.Tensor and its subtleties and use in efficiently evaluating neural networks and (2) the overall framework of language modeling that includes model training, sampling, and the evaluation of a loss (e.g. the negative log likelihood for classification). 
+ +- [YouTube video lecture](https://www.youtube.com/watch?v=PaCmpygFfXo) +- [Jupyter notebook files](lectures/makemore/makemore_part1_bigrams.ipynb) +- [makemore Github repo](https://github.com/karpathy/makemore) + +--- + +(ongoing...) + +--- + + +**License** + +MIT \ No newline at end of file diff --git a/lectures/makemore/makemore_part1_bigrams.ipynb b/lectures/makemore/makemore_part1_bigrams.ipynb new file mode 100644 index 00000000..fbe58a86 --- /dev/null +++ b/lectures/makemore/makemore_part1_bigrams.ipynb @@ -0,0 +1,1913 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "words = open('names.txt', 'r').read().splitlines()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['emma',\n", + " 'olivia',\n", + " 'ava',\n", + " 'isabella',\n", + " 'sophia',\n", + " 'charlotte',\n", + " 'mia',\n", + " 'amelia',\n", + " 'harper',\n", + " 'evelyn']" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "words[:10]" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "32033" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(words)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "min(len(w) for w in words)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "15" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "max(len(w) for w in words)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + 
"metadata": {}, + "outputs": [], + "source": [ + "b = {}\n", + "for w in words:\n", + " chs = [''] + list(w) + ['']\n", + " for ch1, ch2 in zip(chs, chs[1:]):\n", + " bigram = (ch1, ch2)\n", + " b[bigram] = b.get(bigram, 0) + 1" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[(('n', ''), 6763),\n", + " (('a', ''), 6640),\n", + " (('a', 'n'), 5438),\n", + " (('', 'a'), 4410),\n", + " (('e', ''), 3983),\n", + " (('a', 'r'), 3264),\n", + " (('e', 'l'), 3248),\n", + " (('r', 'i'), 3033),\n", + " (('n', 'a'), 2977),\n", + " (('', 'k'), 2963),\n", + " (('l', 'e'), 2921),\n", + " (('e', 'n'), 2675),\n", + " (('l', 'a'), 2623),\n", + " (('m', 'a'), 2590),\n", + " (('', 'm'), 2538),\n", + " (('a', 'l'), 2528),\n", + " (('i', ''), 2489),\n", + " (('l', 'i'), 2480),\n", + " (('i', 'a'), 2445),\n", + " (('', 'j'), 2422),\n", + " (('o', 'n'), 2411),\n", + " (('h', ''), 2409),\n", + " (('r', 'a'), 2356),\n", + " (('a', 'h'), 2332),\n", + " (('h', 'a'), 2244),\n", + " (('y', 'a'), 2143),\n", + " (('i', 'n'), 2126),\n", + " (('', 's'), 2055),\n", + " (('a', 'y'), 2050),\n", + " (('y', ''), 2007),\n", + " (('e', 'r'), 1958),\n", + " (('n', 'n'), 1906),\n", + " (('y', 'n'), 1826),\n", + " (('k', 'a'), 1731),\n", + " (('n', 'i'), 1725),\n", + " (('r', 'e'), 1697),\n", + " (('', 'd'), 1690),\n", + " (('i', 'e'), 1653),\n", + " (('a', 'i'), 1650),\n", + " (('', 'r'), 1639),\n", + " (('a', 'm'), 1634),\n", + " (('l', 'y'), 1588),\n", + " (('', 'l'), 1572),\n", + " (('', 'c'), 1542),\n", + " (('', 'e'), 1531),\n", + " (('j', 'a'), 1473),\n", + " (('r', ''), 1377),\n", + " (('n', 'e'), 1359),\n", + " (('l', 'l'), 1345),\n", + " (('i', 'l'), 1345),\n", + " (('i', 's'), 1316),\n", + " (('l', ''), 1314),\n", + " (('', 't'), 1308),\n", + " (('', 'b'), 1306),\n", + " (('d', 'a'), 1303),\n", + " (('s', 'h'), 1285),\n", + " (('d', 'e'), 1283),\n", + " (('e', 'e'), 1271),\n", + " (('m', 'i'), 1256),\n", + " (('s', 
'a'), 1201),\n", + " (('s', ''), 1169),\n", + " (('', 'n'), 1146),\n", + " (('a', 's'), 1118),\n", + " (('y', 'l'), 1104),\n", + " (('e', 'y'), 1070),\n", + " (('o', 'r'), 1059),\n", + " (('a', 'd'), 1042),\n", + " (('t', 'a'), 1027),\n", + " (('', 'z'), 929),\n", + " (('v', 'i'), 911),\n", + " (('k', 'e'), 895),\n", + " (('s', 'e'), 884),\n", + " (('', 'h'), 874),\n", + " (('r', 'o'), 869),\n", + " (('e', 's'), 861),\n", + " (('z', 'a'), 860),\n", + " (('o', ''), 855),\n", + " (('i', 'r'), 849),\n", + " (('b', 'r'), 842),\n", + " (('a', 'v'), 834),\n", + " (('m', 'e'), 818),\n", + " (('e', 'i'), 818),\n", + " (('c', 'a'), 815),\n", + " (('i', 'y'), 779),\n", + " (('r', 'y'), 773),\n", + " (('e', 'm'), 769),\n", + " (('s', 't'), 765),\n", + " (('h', 'i'), 729),\n", + " (('t', 'e'), 716),\n", + " (('n', 'd'), 704),\n", + " (('l', 'o'), 692),\n", + " (('a', 'e'), 692),\n", + " (('a', 't'), 687),\n", + " (('s', 'i'), 684),\n", + " (('e', 'a'), 679),\n", + " (('d', 'i'), 674),\n", + " (('h', 'e'), 674),\n", + " (('', 'g'), 669),\n", + " (('t', 'o'), 667),\n", + " (('c', 'h'), 664),\n", + " (('b', 'e'), 655),\n", + " (('t', 'h'), 647),\n", + " (('v', 'a'), 642),\n", + " (('o', 'l'), 619),\n", + " (('', 'i'), 591),\n", + " (('i', 'o'), 588),\n", + " (('e', 't'), 580),\n", + " (('v', 'e'), 568),\n", + " (('a', 'k'), 568),\n", + " (('a', 'a'), 556),\n", + " (('c', 'e'), 551),\n", + " (('a', 'b'), 541),\n", + " (('i', 't'), 541),\n", + " (('', 'y'), 535),\n", + " (('t', 'i'), 532),\n", + " (('s', 'o'), 531),\n", + " (('m', ''), 516),\n", + " (('d', ''), 516),\n", + " (('', 'p'), 515),\n", + " (('i', 'c'), 509),\n", + " (('k', 'i'), 509),\n", + " (('o', 's'), 504),\n", + " (('n', 'o'), 496),\n", + " (('t', ''), 483),\n", + " (('j', 'o'), 479),\n", + " (('u', 's'), 474),\n", + " (('a', 'c'), 470),\n", + " (('n', 'y'), 465),\n", + " (('e', 'v'), 463),\n", + " (('s', 's'), 461),\n", + " (('m', 'o'), 452),\n", + " (('i', 'k'), 445),\n", + " (('n', 't'), 443),\n", + " (('i', 
'd'), 440),\n", + " (('j', 'e'), 440),\n", + " (('a', 'z'), 435),\n", + " (('i', 'g'), 428),\n", + " (('i', 'm'), 427),\n", + " (('r', 'r'), 425),\n", + " (('d', 'r'), 424),\n", + " (('', 'f'), 417),\n", + " (('u', 'r'), 414),\n", + " (('r', 'l'), 413),\n", + " (('y', 's'), 401),\n", + " (('', 'o'), 394),\n", + " (('e', 'd'), 384),\n", + " (('a', 'u'), 381),\n", + " (('c', 'o'), 380),\n", + " (('k', 'y'), 379),\n", + " (('d', 'o'), 378),\n", + " (('', 'v'), 376),\n", + " (('t', 't'), 374),\n", + " (('z', 'e'), 373),\n", + " (('z', 'i'), 364),\n", + " (('k', ''), 363),\n", + " (('g', 'h'), 360),\n", + " (('t', 'r'), 352),\n", + " (('k', 'o'), 344),\n", + " (('t', 'y'), 341),\n", + " (('g', 'e'), 334),\n", + " (('g', 'a'), 330),\n", + " (('l', 'u'), 324),\n", + " (('b', 'a'), 321),\n", + " (('d', 'y'), 317),\n", + " (('c', 'k'), 316),\n", + " (('', 'w'), 307),\n", + " (('k', 'h'), 307),\n", + " (('u', 'l'), 301),\n", + " (('y', 'e'), 301),\n", + " (('y', 'r'), 291),\n", + " (('m', 'y'), 287),\n", + " (('h', 'o'), 287),\n", + " (('w', 'a'), 280),\n", + " (('s', 'l'), 279),\n", + " (('n', 's'), 278),\n", + " (('i', 'z'), 277),\n", + " (('u', 'n'), 275),\n", + " (('o', 'u'), 275),\n", + " (('n', 'g'), 273),\n", + " (('y', 'd'), 272),\n", + " (('c', 'i'), 271),\n", + " (('y', 'o'), 271),\n", + " (('i', 'v'), 269),\n", + " (('e', 'o'), 269),\n", + " (('o', 'm'), 261),\n", + " (('r', 'u'), 252),\n", + " (('f', 'a'), 242),\n", + " (('b', 'i'), 217),\n", + " (('s', 'y'), 215),\n", + " (('n', 'c'), 213),\n", + " (('h', 'y'), 213),\n", + " (('p', 'a'), 209),\n", + " (('r', 't'), 208),\n", + " (('q', 'u'), 206),\n", + " (('p', 'h'), 204),\n", + " (('h', 'r'), 204),\n", + " (('j', 'u'), 202),\n", + " (('g', 'r'), 201),\n", + " (('p', 'e'), 197),\n", + " (('n', 'l'), 195),\n", + " (('y', 'i'), 192),\n", + " (('g', 'i'), 190),\n", + " (('o', 'd'), 190),\n", + " (('r', 's'), 190),\n", + " (('r', 'd'), 187),\n", + " (('h', 'l'), 185),\n", + " (('s', 'u'), 185),\n", + " (('a', 'x'), 
182),\n", + " (('e', 'z'), 181),\n", + " (('e', 'k'), 178),\n", + " (('o', 'v'), 176),\n", + " (('a', 'j'), 175),\n", + " (('o', 'h'), 171),\n", + " (('u', 'e'), 169),\n", + " (('m', 'm'), 168),\n", + " (('a', 'g'), 168),\n", + " (('h', 'u'), 166),\n", + " (('x', ''), 164),\n", + " (('u', 'a'), 163),\n", + " (('r', 'm'), 162),\n", + " (('a', 'w'), 161),\n", + " (('f', 'i'), 160),\n", + " (('z', ''), 160),\n", + " (('u', ''), 155),\n", + " (('u', 'm'), 154),\n", + " (('e', 'c'), 153),\n", + " (('v', 'o'), 153),\n", + " (('e', 'h'), 152),\n", + " (('p', 'r'), 151),\n", + " (('d', 'd'), 149),\n", + " (('o', 'a'), 149),\n", + " (('w', 'e'), 149),\n", + " (('w', 'i'), 148),\n", + " (('y', 'm'), 148),\n", + " (('z', 'y'), 147),\n", + " (('n', 'z'), 145),\n", + " (('y', 'u'), 141),\n", + " (('r', 'n'), 140),\n", + " (('o', 'b'), 140),\n", + " (('k', 'l'), 139),\n", + " (('m', 'u'), 139),\n", + " (('l', 'd'), 138),\n", + " (('h', 'n'), 138),\n", + " (('u', 'd'), 136),\n", + " (('', 'x'), 134),\n", + " (('t', 'l'), 134),\n", + " (('a', 'f'), 134),\n", + " (('o', 'e'), 132),\n", + " (('e', 'x'), 132),\n", + " (('e', 'g'), 125),\n", + " (('f', 'e'), 123),\n", + " (('z', 'l'), 123),\n", + " (('u', 'i'), 121),\n", + " (('v', 'y'), 121),\n", + " (('e', 'b'), 121),\n", + " (('r', 'h'), 121),\n", + " (('j', 'i'), 119),\n", + " (('o', 't'), 118),\n", + " (('d', 'h'), 118),\n", + " (('h', 'm'), 117),\n", + " (('c', 'l'), 116),\n", + " (('o', 'o'), 115),\n", + " (('y', 'c'), 115),\n", + " (('o', 'w'), 114),\n", + " (('o', 'c'), 114),\n", + " (('f', 'r'), 114),\n", + " (('b', ''), 114),\n", + " (('m', 'b'), 112),\n", + " (('z', 'o'), 110),\n", + " (('i', 'b'), 110),\n", + " (('i', 'u'), 109),\n", + " (('k', 'r'), 109),\n", + " (('g', ''), 108),\n", + " (('y', 'v'), 106),\n", + " (('t', 'z'), 105),\n", + " (('b', 'o'), 105),\n", + " (('c', 'y'), 104),\n", + " (('y', 't'), 104),\n", + " (('u', 'b'), 103),\n", + " (('u', 'c'), 103),\n", + " (('x', 'a'), 103),\n", + " (('b', 'l'), 
103),\n", + " (('o', 'y'), 103),\n", + " (('x', 'i'), 102),\n", + " (('i', 'f'), 101),\n", + " (('r', 'c'), 99),\n", + " (('c', ''), 97),\n", + " (('m', 'r'), 97),\n", + " (('n', 'u'), 96),\n", + " (('o', 'p'), 95),\n", + " (('i', 'h'), 95),\n", + " (('k', 's'), 95),\n", + " (('l', 's'), 94),\n", + " (('u', 'k'), 93),\n", + " (('', 'q'), 92),\n", + " (('d', 'u'), 92),\n", + " (('s', 'm'), 90),\n", + " (('r', 'k'), 90),\n", + " (('i', 'x'), 89),\n", + " (('v', ''), 88),\n", + " (('y', 'k'), 86),\n", + " (('u', 'w'), 86),\n", + " (('g', 'u'), 85),\n", + " (('b', 'y'), 83),\n", + " (('e', 'p'), 83),\n", + " (('g', 'o'), 83),\n", + " (('s', 'k'), 82),\n", + " (('u', 't'), 82),\n", + " (('a', 'p'), 82),\n", + " (('e', 'f'), 82),\n", + " (('i', 'i'), 82),\n", + " (('r', 'v'), 80),\n", + " (('f', ''), 80),\n", + " (('t', 'u'), 78),\n", + " (('y', 'z'), 78),\n", + " (('', 'u'), 78),\n", + " (('l', 't'), 77),\n", + " (('r', 'g'), 76),\n", + " (('c', 'r'), 76),\n", + " (('i', 'j'), 76),\n", + " (('w', 'y'), 73),\n", + " (('z', 'u'), 73),\n", + " (('l', 'v'), 72),\n", + " (('h', 't'), 71),\n", + " (('j', ''), 71),\n", + " (('x', 't'), 70),\n", + " (('o', 'i'), 69),\n", + " (('e', 'u'), 69),\n", + " (('o', 'k'), 68),\n", + " (('b', 'd'), 65),\n", + " (('a', 'o'), 63),\n", + " (('p', 'i'), 61),\n", + " (('s', 'c'), 60),\n", + " (('d', 'l'), 60),\n", + " (('l', 'm'), 60),\n", + " (('a', 'q'), 60),\n", + " (('f', 'o'), 60),\n", + " (('p', 'o'), 59),\n", + " (('n', 'k'), 58),\n", + " (('w', 'n'), 58),\n", + " (('u', 'h'), 58),\n", + " (('e', 'j'), 55),\n", + " (('n', 'v'), 55),\n", + " (('s', 'r'), 55),\n", + " (('o', 'z'), 54),\n", + " (('i', 'p'), 53),\n", + " (('l', 'b'), 52),\n", + " (('i', 'q'), 52),\n", + " (('w', ''), 51),\n", + " (('m', 'c'), 51),\n", + " (('s', 'p'), 51),\n", + " (('e', 'w'), 50),\n", + " (('k', 'u'), 50),\n", + " (('v', 'r'), 48),\n", + " (('u', 'g'), 47),\n", + " (('o', 'x'), 45),\n", + " (('u', 'z'), 45),\n", + " (('z', 'z'), 45),\n", + " (('j', 'h'), 
45),\n", + " (('b', 'u'), 45),\n", + " (('o', 'g'), 44),\n", + " (('n', 'r'), 44),\n", + " (('f', 'f'), 44),\n", + " (('n', 'j'), 44),\n", + " (('z', 'h'), 43),\n", + " (('c', 'c'), 42),\n", + " (('r', 'b'), 41),\n", + " (('x', 'o'), 41),\n", + " (('b', 'h'), 41),\n", + " (('p', 'p'), 39),\n", + " (('x', 'l'), 39),\n", + " (('h', 'v'), 39),\n", + " (('b', 'b'), 38),\n", + " (('m', 'p'), 38),\n", + " (('x', 'x'), 38),\n", + " (('u', 'v'), 37),\n", + " (('x', 'e'), 36),\n", + " (('w', 'o'), 36),\n", + " (('c', 't'), 35),\n", + " (('z', 'm'), 35),\n", + " (('t', 's'), 35),\n", + " (('m', 's'), 35),\n", + " (('c', 'u'), 35),\n", + " (('o', 'f'), 34),\n", + " (('u', 'x'), 34),\n", + " (('k', 'w'), 34),\n", + " (('p', ''), 33),\n", + " (('g', 'l'), 32),\n", + " (('z', 'r'), 32),\n", + " (('d', 'n'), 31),\n", + " (('g', 't'), 31),\n", + " (('g', 'y'), 31),\n", + " (('h', 's'), 31),\n", + " (('x', 's'), 31),\n", + " (('g', 's'), 30),\n", + " (('x', 'y'), 30),\n", + " (('y', 'g'), 30),\n", + " (('d', 'm'), 30),\n", + " (('d', 's'), 29),\n", + " (('h', 'k'), 29),\n", + " (('y', 'x'), 28),\n", + " (('q', ''), 28),\n", + " (('g', 'n'), 27),\n", + " (('y', 'b'), 27),\n", + " (('g', 'w'), 26),\n", + " (('n', 'h'), 26),\n", + " (('k', 'n'), 26),\n", + " (('g', 'g'), 25),\n", + " (('d', 'g'), 25),\n", + " (('l', 'c'), 25),\n", + " (('r', 'j'), 25),\n", + " (('w', 'u'), 25),\n", + " (('l', 'k'), 24),\n", + " (('m', 'd'), 24),\n", + " (('s', 'w'), 24),\n", + " (('s', 'n'), 24),\n", + " (('h', 'd'), 24),\n", + " (('w', 'h'), 23),\n", + " (('y', 'j'), 23),\n", + " (('y', 'y'), 23),\n", + " (('r', 'z'), 23),\n", + " (('d', 'w'), 23),\n", + " (('w', 'r'), 22),\n", + " (('t', 'n'), 22),\n", + " (('l', 'f'), 22),\n", + " (('y', 'h'), 22),\n", + " (('r', 'w'), 21),\n", + " (('s', 'b'), 21),\n", + " (('m', 'n'), 20),\n", + " (('f', 'l'), 20),\n", + " (('w', 's'), 20),\n", + " (('k', 'k'), 20),\n", + " (('h', 'z'), 20),\n", + " (('g', 'd'), 19),\n", + " (('l', 'h'), 19),\n", + " (('n', 'm'), 
19),\n", + " (('x', 'z'), 19),\n", + " (('u', 'f'), 19),\n", + " (('f', 't'), 18),\n", + " (('l', 'r'), 18),\n", + " (('p', 't'), 17),\n", + " (('t', 'c'), 17),\n", + " (('k', 't'), 17),\n", + " (('d', 'v'), 17),\n", + " (('u', 'p'), 16),\n", + " (('p', 'l'), 16),\n", + " (('l', 'w'), 16),\n", + " (('p', 's'), 16),\n", + " (('o', 'j'), 16),\n", + " (('r', 'q'), 16),\n", + " (('y', 'p'), 15),\n", + " (('l', 'p'), 15),\n", + " (('t', 'v'), 15),\n", + " (('r', 'p'), 14),\n", + " (('l', 'n'), 14),\n", + " (('e', 'q'), 14),\n", + " (('f', 'y'), 14),\n", + " (('s', 'v'), 14),\n", + " (('u', 'j'), 14),\n", + " (('v', 'l'), 14),\n", + " (('q', 'a'), 13),\n", + " (('u', 'y'), 13),\n", + " (('q', 'i'), 13),\n", + " (('w', 'l'), 13),\n", + " (('p', 'y'), 12),\n", + " (('y', 'f'), 12),\n", + " (('c', 'q'), 11),\n", + " (('j', 'r'), 11),\n", + " (('n', 'w'), 11),\n", + " (('n', 'f'), 11),\n", + " (('t', 'w'), 11),\n", + " (('m', 'z'), 11),\n", + " (('u', 'o'), 10),\n", + " (('f', 'u'), 10),\n", + " (('l', 'z'), 10),\n", + " (('h', 'w'), 10),\n", + " (('u', 'q'), 10),\n", + " (('j', 'y'), 10),\n", + " (('s', 'z'), 10),\n", + " (('s', 'd'), 9),\n", + " (('j', 'l'), 9),\n", + " (('d', 'j'), 9),\n", + " (('k', 'm'), 9),\n", + " (('r', 'f'), 9),\n", + " (('h', 'j'), 9),\n", + " (('v', 'n'), 8),\n", + " (('n', 'b'), 8),\n", + " (('i', 'w'), 8),\n", + " (('h', 'b'), 8),\n", + " (('b', 's'), 8),\n", + " (('w', 't'), 8),\n", + " (('w', 'd'), 8),\n", + " (('v', 'v'), 7),\n", + " (('v', 'u'), 7),\n", + " (('j', 's'), 7),\n", + " (('m', 'j'), 7),\n", + " (('f', 's'), 6),\n", + " (('l', 'g'), 6),\n", + " (('l', 'j'), 6),\n", + " (('j', 'w'), 6),\n", + " (('n', 'x'), 6),\n", + " (('y', 'q'), 6),\n", + " (('w', 'k'), 6),\n", + " (('g', 'm'), 6),\n", + " (('x', 'u'), 5),\n", + " (('m', 'h'), 5),\n", + " (('m', 'l'), 5),\n", + " (('j', 'm'), 5),\n", + " (('c', 's'), 5),\n", + " (('j', 'v'), 5),\n", + " (('n', 'p'), 5),\n", + " (('d', 'f'), 5),\n", + " (('x', 'd'), 5),\n", + " (('z', 'b'), 
4),\n", + " (('f', 'n'), 4),\n", + " (('x', 'c'), 4),\n", + " (('m', 't'), 4),\n", + " (('t', 'm'), 4),\n", + " (('z', 'n'), 4),\n", + " (('z', 't'), 4),\n", + " (('p', 'u'), 4),\n", + " (('c', 'z'), 4),\n", + " (('b', 'n'), 4),\n", + " (('z', 's'), 4),\n", + " (('f', 'w'), 4),\n", + " (('d', 't'), 4),\n", + " (('j', 'd'), 4),\n", + " (('j', 'c'), 4),\n", + " (('y', 'w'), 4),\n", + " (('v', 'k'), 3),\n", + " (('x', 'w'), 3),\n", + " (('t', 'j'), 3),\n", + " (('c', 'j'), 3),\n", + " (('q', 'w'), 3),\n", + " (('g', 'b'), 3),\n", + " (('o', 'q'), 3),\n", + " (('r', 'x'), 3),\n", + " (('d', 'c'), 3),\n", + " (('g', 'j'), 3),\n", + " (('x', 'f'), 3),\n", + " (('z', 'w'), 3),\n", + " (('d', 'k'), 3),\n", + " (('u', 'u'), 3),\n", + " (('m', 'v'), 3),\n", + " (('c', 'x'), 3),\n", + " (('l', 'q'), 3),\n", + " (('p', 'b'), 2),\n", + " (('t', 'g'), 2),\n", + " (('q', 's'), 2),\n", + " (('t', 'x'), 2),\n", + " (('f', 'k'), 2),\n", + " (('b', 't'), 2),\n", + " (('j', 'n'), 2),\n", + " (('k', 'c'), 2),\n", + " (('z', 'k'), 2),\n", + " (('s', 'j'), 2),\n", + " (('s', 'f'), 2),\n", + " (('z', 'j'), 2),\n", + " (('n', 'q'), 2),\n", + " (('f', 'z'), 2),\n", + " (('h', 'g'), 2),\n", + " (('w', 'w'), 2),\n", + " (('k', 'j'), 2),\n", + " (('j', 'k'), 2),\n", + " (('w', 'm'), 2),\n", + " (('z', 'c'), 2),\n", + " (('z', 'v'), 2),\n", + " (('w', 'f'), 2),\n", + " (('q', 'm'), 2),\n", + " (('k', 'z'), 2),\n", + " (('j', 'j'), 2),\n", + " (('z', 'p'), 2),\n", + " (('j', 't'), 2),\n", + " (('k', 'b'), 2),\n", + " (('m', 'w'), 2),\n", + " (('h', 'f'), 2),\n", + " (('c', 'g'), 2),\n", + " (('t', 'f'), 2),\n", + " (('h', 'c'), 2),\n", + " (('q', 'o'), 2),\n", + " (('k', 'd'), 2),\n", + " (('k', 'v'), 2),\n", + " (('s', 'g'), 2),\n", + " (('z', 'd'), 2),\n", + " (('q', 'r'), 1),\n", + " (('d', 'z'), 1),\n", + " (('p', 'j'), 1),\n", + " (('q', 'l'), 1),\n", + " (('p', 'f'), 1),\n", + " (('q', 'e'), 1),\n", + " (('b', 'c'), 1),\n", + " (('c', 'd'), 1),\n", + " (('m', 'f'), 1),\n", + " (('p', 'n'), 
1),\n", + " (('w', 'b'), 1),\n", + " (('p', 'c'), 1),\n", + " (('h', 'p'), 1),\n", + " (('f', 'h'), 1),\n", + " (('b', 'j'), 1),\n", + " (('f', 'g'), 1),\n", + " (('z', 'g'), 1),\n", + " (('c', 'p'), 1),\n", + " (('p', 'k'), 1),\n", + " (('p', 'm'), 1),\n", + " (('x', 'n'), 1),\n", + " (('s', 'q'), 1),\n", + " (('k', 'f'), 1),\n", + " (('m', 'k'), 1),\n", + " (('x', 'h'), 1),\n", + " (('g', 'f'), 1),\n", + " (('v', 'b'), 1),\n", + " (('j', 'p'), 1),\n", + " (('g', 'z'), 1),\n", + " (('v', 'd'), 1),\n", + " (('d', 'b'), 1),\n", + " (('v', 'h'), 1),\n", + " (('h', 'h'), 1),\n", + " (('g', 'v'), 1),\n", + " (('d', 'q'), 1),\n", + " (('x', 'b'), 1),\n", + " (('w', 'z'), 1),\n", + " (('h', 'q'), 1),\n", + " (('j', 'b'), 1),\n", + " (('x', 'm'), 1),\n", + " (('w', 'g'), 1),\n", + " (('t', 'b'), 1),\n", + " (('z', 'x'), 1)]" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sorted(b.items(), key = lambda kv: -kv[1])" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 365, + "metadata": {}, + "outputs": [], + "source": [ + "N = torch.zeros((27, 27), dtype=torch.int32)" + ] + }, + { + "cell_type": "code", + "execution_count": 366, + "metadata": {}, + "outputs": [], + "source": [ + "chars = sorted(list(set(''.join(words))))\n", + "stoi = {s:i+1 for i,s in enumerate(chars)}\n", + "stoi['.'] = 0\n", + "itos = {i:s for s,i in stoi.items()}" + ] + }, + { + "cell_type": "code", + "execution_count": 367, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "for w in words:\n", + " chs = ['.'] + list(w) + ['.']\n", + " for ch1, ch2 in zip(chs, chs[1:]):\n", + " ix1 = stoi[ch1]\n", + " ix2 = stoi[ch2]\n", + " N[ix1, ix2] += 1\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 368, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n", + "\n", + "plt.figure(figsize=(16,16))\n", + "plt.imshow(N, cmap='Blues')\n", + "for i in range(27):\n", + " for j in range(27):\n", + " chstr = itos[i] + itos[j]\n", + " plt.text(j, i, chstr, ha=\"center\", va=\"bottom\", color='gray')\n", + " plt.text(j, i, N[i, j].item(), ha=\"center\", va=\"top\", color='gray')\n", + "plt.axis('off');" + ] + }, + { + "cell_type": "code", + "execution_count": 159, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1.])" + ] + }, + "execution_count": 159, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "N[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0.0000, 0.1377, 0.0408, 0.0481, 0.0528, 0.0478, 0.0130, 0.0209, 0.0273,\n", + " 0.0184, 0.0756, 0.0925, 0.0491, 0.0792, 0.0358, 0.0123, 0.0161, 0.0029,\n", + " 0.0512, 0.0642, 0.0408, 0.0024, 0.0117, 0.0096, 0.0042, 0.0167, 0.0290])" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "p = N[0].float()\n", + "p = p / p.sum()\n", + "p" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'m'" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "g = torch.Generator().manual_seed(2147483647)\n", + "ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()\n", + "itos[ix]" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0.6064, 0.3033, 0.0903])" + 
] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "g = torch.Generator().manual_seed(2147483647)\n", + "p = torch.rand(3, generator=g)\n", + "p = p / p.sum()\n", + "p" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([1, 1, 2, 0, 0, 2, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 2, 0, 0,\n", + " 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1,\n", + " 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,\n", + " 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 1, 0, 0, 2, 0, 1, 0,\n", + " 0, 1, 1, 1])" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.multinomial(p, num_samples=100, replacement=True, generator=g)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([3])" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "p.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([27, 27])" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "P.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([27, 1])" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "P.sum(1, keepdim=True).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "# 27, 27\n", + "# 27, 1" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + 
"torch.Size([27])" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "P.sum(1).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "# 27, 27\n", + "# 1, 27" + ] + }, + { + "cell_type": "code", + "execution_count": 310, + "metadata": {}, + "outputs": [], + "source": [ + "P = (N+1).float()\n", + "P /= P.sum(1, keepdims=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 164, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mor.\n", + "axx.\n", + "minaymoryles.\n", + "kondlaisah.\n", + "anchshizarie.\n" + ] + } + ], + "source": [ + "g = torch.Generator().manual_seed(2147483647)\n", + "\n", + "for i in range(5):\n", + " \n", + " out = []\n", + " ix = 0\n", + " while True:\n", + " p = P[ix]\n", + " ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()\n", + " out.append(itos[ix])\n", + " if ix == 0:\n", + " break\n", + " print(''.join(out))" + ] + }, + { + "cell_type": "code", + "execution_count": 165, + "metadata": {}, + "outputs": [], + "source": [ + "# GOAL: maximize likelihood of the data w.r.t. 
model parameters (statistical modeling)\n", + "# equivalent to maximizing the log likelihood (because log is monotonic)\n", + "# equivalent to minimizing the negative log likelihood\n", + "# equivalent to minimizing the average negative log likelihood\n", + "\n", + "# log(a*b*c) = log(a) + log(b) + log(c)" + ] + }, + { + "cell_type": "code", + "execution_count": 435, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "log_likelihood=tensor(-564996.8125, grad_fn=)\n", + "nll=tensor(564996.8125, grad_fn=)\n", + "2.476470470428467\n" + ] + } + ], + "source": [ + "log_likelihood = 0.0\n", + "n = 0\n", + "\n", + "for w in words:\n", + "#for w in [\"andrejq\"]:\n", + " chs = ['.'] + list(w) + ['.']\n", + " for ch1, ch2 in zip(chs, chs[1:]):\n", + " ix1 = stoi[ch1]\n", + " ix2 = stoi[ch2]\n", + " prob = P[ix1, ix2]\n", + " logprob = torch.log(prob)\n", + " log_likelihood += logprob\n", + " n += 1\n", + " #print(f'{ch1}{ch2}: {prob:.4f} {logprob:.4f}')\n", + "\n", + "print(f'{log_likelihood=}')\n", + "nll = -log_likelihood\n", + "print(f'{nll=}')\n", + "print(f'{nll/n}')" + ] + }, + { + "cell_type": "code", + "execution_count": 449, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + ". 
e\n", + "e m\n", + "m m\n", + "m a\n", + "a .\n" + ] + } + ], + "source": [ + "# create the training set of bigrams (x,y)\n", + "xs, ys = [], []\n", + "\n", + "for w in words[:1]:\n", + " chs = ['.'] + list(w) + ['.']\n", + " for ch1, ch2 in zip(chs, chs[1:]):\n", + " ix1 = stoi[ch1]\n", + " ix2 = stoi[ch2]\n", + " print(ch1, ch2)\n", + " xs.append(ix1)\n", + " ys.append(ix2)\n", + " \n", + "xs = torch.tensor(xs)\n", + "ys = torch.tensor(ys)" + ] + }, + { + "cell_type": "code", + "execution_count": 450, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 0, 5, 13, 13, 1])" + ] + }, + "execution_count": 450, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "xs" + ] + }, + { + "cell_type": "code", + "execution_count": 451, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 5, 13, 13, 1, 0])" + ] + }, + "execution_count": 451, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ys" + ] + }, + { + "cell_type": "code", + "execution_count": 487, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0.]])" + ] + }, + "execution_count": 487, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import torch.nn.functional as F\n", + "xenc = F.one_hot(xs, num_classes=27).float()\n", + "xenc" + ] + }, + { + 
"cell_type": "code", + "execution_count": 488, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([5, 27])" + ] + }, + "execution_count": 488, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "xenc.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 489, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 489, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAABdCAYAAACM0CxCAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAGsklEQVR4nO3dT4hdZxnH8e/PcVpJ20Vrq7RJNFW6KS5SGbqJSClo/yhGF0oDSruKCwspCFrd2I0goqUbEaINVKwWoVWDFGLRFnUT88eQNh0aQ4k2JiTVLtoKNrZ9XNwbHNM7mTs459y3934/EObOuWfmPE/ey2/eeeecc1NVSJLa9Y5JFyBJujCDWpIaZ1BLUuMMaklqnEEtSY0zqCWpce/s4pteecVcbdo4P/b+Rw+v66IMSXrb+Bf/5Gy9llHPdRLUmzbO88c9G8fe/5ZrNndRhiS9beyt3yz73FhLH0luTfJckmNJ7l2zyiRJK1oxqJPMAd8DbgOuB7Ylub7rwiRJA+PMqG8EjlXV81V1FngE2NptWZKkc8YJ6vXAC0s+PzHcJknqwThBPeqvkG+5k1OS7Un2J9n/4j/e+P8rkyQB4wX1CWDpKRwbgJPn71RVO6tqoaoWrnr33FrVJ0kzb5yg3gdcl+TaJBcBdwC7uy1LknTOiudRV9XrSe4G9gBzwK6qOtJ5ZZIkYMwLXqrqceDxjmuRJI3gvT4kqXGdXEJ+9PC6mbwsfM/JQ6vafxb/jyStnjNqSWqcQS1JjTOoJalxBrUkNc6glqTGGdSS1DiDWpIaZ1BLUuMMaklqnEEtSY0zqCWpcQa1JDWuk5syzSpvstSO1d4gCxw/tcsZtSQ1bsWgTrIxyZNJFpMcSbKjj8IkSQPjLH28Dny5qg4muQw4kOSJqnq249okSYwxo66qU1V1cPj4FWARWN91YZKkgVWtUSfZBNwA7O2kGknSW4x91keSS4FHgXuq6uURz28HtgO8i3VrVqAkzbqxZtRJ5hmE9MNV9diofapqZ1UtVNXCPBevZY2SNNPGOesjwIPAYlXd331JkqSlxplRbwG+ANyc5NDw3+0d1yVJGlpxjbqq/gCkh1okSSN4ZaIkNc6glqTGGdSS1DiDWpIaZ1BLUuMMaklqnEEtSY0zqCWpcQa1JDXOoJakxhnUktQ4g1qSGmdQS1LjDGpJatzYb8XVpT0nD636a265ZvOa16Hp4etD08QZtSQ1buygTjKX5E9JftVlQZKk/7WaGfUOYLGrQiRJo437LuQbgE8AP+y2HEnS+cadUT8AfAV4s7tSJEmjrBjUST4JnKmqAyvstz3J/iT7/81ra1agJM26cWbUW4BPJTkOPALcnOTH5+9UVTuraqGqFua5eI3LlKTZtWJQV9XXqmpDVW0C7gB+W1Wf77wySRLgedSS1LxVXZlYVU8BT3VSiSRpJGfUktS4VNXaf9PkReAvI566Evj7mh+wffY9W+x7tqx
V3++vqqtGPdFJUC8nyf6qWujtgI2w79li37Olj75d+pCkxhnUktS4voN6Z8/Ha4V9zxb7ni2d993rGrUkafVc+pCkxvUS1EluTfJckmNJ7u3jmC1IcjzJ00kOJdk/6Xq6lGRXkjNJnlmy7YokTyT58/Dj5ZOssQvL9H1fkr8Nx/1QktsnWeNaS7IxyZNJFpMcSbJjuH2qx/sCfXc+3p0vfSSZA44CHwNOAPuAbVX1bKcHbsDwRlYLVTX155Ym+SjwKvCjqvrQcNu3gZeq6lvDH9CXV9VXJ1nnWlum7/uAV6vqO5OsrStJrgaurqqDSS4DDgCfBu5iisf7An1/jo7Hu48Z9Y3Asap6vqrOMrgD39YejqseVdXvgJfO27wVeGj4+CEGL+qpskzfU62qTlXVweHjVxi889N6pny8L9B35/oI6vXAC0s+P0FPzTWggF8nOZBk+6SLmYD3VtUpGLzIgfdMuJ4+3Z3k8HBpZKqWAJZKsgm4AdjLDI33eX1Dx+PdR1BnxLZZOdVkS1V9GLgN+NLw12RNv+8DHwQ2A6eA7060mo4kuRR4FLinql6edD19GdF35+PdR1CfADYu+XwDcLKH405cVZ0cfjwD/JzBMtAsOT1c1zu3vndmwvX0oqpOV9UbVfUm8AOmcNyTzDMIq4er6rHh5qkf71F99zHefQT1PuC6JNcmuYjBmw/s7uG4E5XkkuEfHEhyCfBx4JkLf9XU2Q3cOXx8J/DLCdbSm3NhNfQZpmzckwR4EFisqvuXPDXV471c332Mdy8XvAxPV3kAmAN2VdU3Oz/ohCX5AINZNAzu+/2Tae47yU+BmxjcSew08A3gF8DPgPcBfwU+W1VT9Ye3Zfq+icGvwQUcB754bu12GiT5CPB74Gn++4bXX2ewXju1432BvrfR8Xh7ZaIkNc4rEyWpcQa1JDXOoJakxhnUktQ4g1qSGmdQS1LjDGpJapxBLUmN+w9AXCCNBMImrgAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.imshow(xenc)" + ] + }, + { + "cell_type": "code", + "execution_count": 490, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.float32" + ] + }, + "execution_count": 490, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "xenc.dtype" + ] + }, + { + "cell_type": "code", + "execution_count": 493, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[-0.2003, -2.3711, -0.9466, 0.5369, -0.0949, -1.7872, -0.9038, 0.8194,\n", + " 0.6926, 0.0114, -1.5301, 0.6077, -1.2056, 1.8605, -1.3012, -0.0301,\n", + " -2.1611, -0.0538, -0.0133, -0.3629, 0.5254, -0.0080, 1.1602, 1.9851,\n", + " 0.4976, 0.7351, -0.6373],\n", + " [-0.4422, 0.5024, 1.3514, -0.4085, -0.7854, -1.2568, -0.4558, 0.1466,\n", + " -0.4460, 1.2748, -0.6367, 0.6403, -0.5617, -0.3060, 1.6771, -1.4814,\n", + " -2.7395, 0.3876, 0.3970, 1.5577, -0.1995, -0.1397, -1.3045, 0.4294,\n", + " 1.2557, 0.8007, 0.5450],\n", + " [-0.2680, -0.2640, 0.4591, 0.0338, 0.7478, 1.2757, -0.9842, 0.1799,\n", + " 0.0824, -0.5646, -0.3657, -0.8358, -1.7654, 0.5008, -1.7455, -0.8160,\n", + " -2.2721, 0.9713, -1.0734, 0.3115, -0.2506, 0.0757, 0.9332, 1.6536,\n", + " 1.2306, 0.1231, -0.2530],\n", + " [-0.2680, -0.2640, 0.4591, 0.0338, 0.7478, 1.2757, -0.9842, 0.1799,\n", + " 0.0824, -0.5646, -0.3657, -0.8358, -1.7654, 0.5008, -1.7455, -0.8160,\n", + " -2.2721, 0.9713, -1.0734, 0.3115, -0.2506, 0.0757, 0.9332, 1.6536,\n", + " 1.2306, 0.1231, -0.2530],\n", + " [ 0.1949, -1.1315, 0.9479, -0.6382, -0.4422, -0.6489, 0.6576, -1.9004,\n", + " 2.0254, 1.2617, -1.7238, 1.2971, -0.6925, -0.3873, 0.7874, -0.8088,\n", + " 0.5746, -0.5263, -0.5928, 0.1419, 1.0683, -0.1760, -0.3507, -0.5358,\n", + " 0.1470, 1.5682, -1.0393]])" + ] + }, + "execution_count": 493, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "W = torch.randn((27, 
1))\n", + "xenc @ W" + ] + }, + { + "cell_type": "code", + "execution_count": 506, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[0.0205, 0.0023, 0.0097, 0.0428, 0.0228, 0.0042, 0.0101, 0.0568, 0.0500,\n", + " 0.0253, 0.0054, 0.0460, 0.0075, 0.1609, 0.0068, 0.0243, 0.0029, 0.0237,\n", + " 0.0247, 0.0174, 0.0423, 0.0248, 0.0799, 0.1822, 0.0412, 0.0522, 0.0132],\n", + " [0.0154, 0.0397, 0.0928, 0.0160, 0.0110, 0.0068, 0.0152, 0.0278, 0.0154,\n", + " 0.0860, 0.0127, 0.0456, 0.0137, 0.0177, 0.1286, 0.0055, 0.0016, 0.0354,\n", + " 0.0357, 0.1141, 0.0197, 0.0209, 0.0065, 0.0369, 0.0844, 0.0535, 0.0414],\n", + " [0.0212, 0.0213, 0.0439, 0.0287, 0.0586, 0.0994, 0.0104, 0.0332, 0.0301,\n", + " 0.0158, 0.0192, 0.0120, 0.0047, 0.0458, 0.0048, 0.0123, 0.0029, 0.0733,\n", + " 0.0095, 0.0379, 0.0216, 0.0299, 0.0705, 0.1450, 0.0950, 0.0314, 0.0215],\n", + " [0.0212, 0.0213, 0.0439, 0.0287, 0.0586, 0.0994, 0.0104, 0.0332, 0.0301,\n", + " 0.0158, 0.0192, 0.0120, 0.0047, 0.0458, 0.0048, 0.0123, 0.0029, 0.0733,\n", + " 0.0095, 0.0379, 0.0216, 0.0299, 0.0705, 0.1450, 0.0950, 0.0314, 0.0215],\n", + " [0.0289, 0.0077, 0.0613, 0.0126, 0.0153, 0.0124, 0.0459, 0.0036, 0.1801,\n", + " 0.0839, 0.0042, 0.0869, 0.0119, 0.0161, 0.0522, 0.0106, 0.0422, 0.0140,\n", + " 0.0131, 0.0274, 0.0692, 0.0199, 0.0167, 0.0139, 0.0275, 0.1140, 0.0084]])" + ] + }, + "execution_count": 506, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "logits = xenc @ W # log-counts\n", + "counts = logits.exp() # equivalent N\n", + "probs = counts / counts.sum(1, keepdims=True)\n", + "probs" + ] + }, + { + "cell_type": "code", + "execution_count": 509, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0.0205, 0.0023, 0.0097, 0.0428, 0.0228, 0.0042, 0.0101, 0.0568, 0.0500,\n", + " 0.0253, 0.0054, 0.0460, 0.0075, 0.1609, 0.0068, 0.0243, 0.0029, 0.0237,\n", + " 0.0247, 0.0174, 0.0423, 0.0248, 0.0799, 0.1822, 0.0412, 0.0522, 0.0132])" + ] + 
}, + "execution_count": 509, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "probs[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 510, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([27])" + ] + }, + "execution_count": 510, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "probs[0].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 507, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(1.)" + ] + }, + "execution_count": 507, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "probs[0].sum()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# (5, 27) @ (27, 27) -> (5, 27)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# SUMMARY ------------------------------>>>>" + ] + }, + { + "cell_type": "code", + "execution_count": 528, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 0, 5, 13, 13, 1])" + ] + }, + "execution_count": 528, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "xs" + ] + }, + { + "cell_type": "code", + "execution_count": 529, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 5, 13, 13, 1, 0])" + ] + }, + "execution_count": 529, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ys" + ] + }, + { + "cell_type": "code", + "execution_count": 557, + "metadata": {}, + "outputs": [], + "source": [ + "# randomly initialize 27 neurons' weights. 
each neuron receives 27 inputs\n", + "g = torch.Generator().manual_seed(2147483647)\n", + "W = torch.randn((27, 27), generator=g)" + ] + }, + { + "cell_type": "code", + "execution_count": 558, + "metadata": {}, + "outputs": [], + "source": [ + "xenc = F.one_hot(xs, num_classes=27).float() # input to the network: one-hot encoding\n", + "logits = xenc @ W # predict log-counts\n", + "counts = logits.exp() # counts, equivalent to N\n", + "probs = counts / counts.sum(1, keepdims=True) # probabilities for next character\n", + "# btw: the last 2 lines here are together called a 'softmax'" + ] + }, + { + "cell_type": "code", + "execution_count": 559, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([5, 27])" + ] + }, + "execution_count": 559, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "probs.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 560, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--------\n", + "bigram example 1: .e (indexes 0,5)\n", + "input to the neural net: 0\n", + "output probabilities from the neural net: tensor([0.0607, 0.0100, 0.0123, 0.0042, 0.0168, 0.0123, 0.0027, 0.0232, 0.0137,\n", + " 0.0313, 0.0079, 0.0278, 0.0091, 0.0082, 0.0500, 0.2378, 0.0603, 0.0025,\n", + " 0.0249, 0.0055, 0.0339, 0.0109, 0.0029, 0.0198, 0.0118, 0.1537, 0.1459])\n", + "label (actual next character): 5\n", + "probability assigned by the net to the the correct character: 0.012286253273487091\n", + "log likelihood: -4.3992743492126465\n", + "negative log likelihood: 4.3992743492126465\n", + "--------\n", + "bigram example 2: em (indexes 5,13)\n", + "input to the neural net: 5\n", + "output probabilities from the neural net: tensor([0.0290, 0.0796, 0.0248, 0.0521, 0.1989, 0.0289, 0.0094, 0.0335, 0.0097,\n", + " 0.0301, 0.0702, 0.0228, 0.0115, 0.0181, 0.0108, 0.0315, 0.0291, 0.0045,\n", + " 0.0916, 0.0215, 0.0486, 0.0300, 0.0501, 0.0027, 0.0118, 
0.0022, 0.0472])\n", + "label (actual next character): 13\n", + "probability assigned by the net to the the correct character: 0.018050702288746834\n", + "log likelihood: -4.014570713043213\n", + "negative log likelihood: 4.014570713043213\n", + "--------\n", + "bigram example 3: mm (indexes 13,13)\n", + "input to the neural net: 13\n", + "output probabilities from the neural net: tensor([0.0312, 0.0737, 0.0484, 0.0333, 0.0674, 0.0200, 0.0263, 0.0249, 0.1226,\n", + " 0.0164, 0.0075, 0.0789, 0.0131, 0.0267, 0.0147, 0.0112, 0.0585, 0.0121,\n", + " 0.0650, 0.0058, 0.0208, 0.0078, 0.0133, 0.0203, 0.1204, 0.0469, 0.0126])\n", + "label (actual next character): 13\n", + "probability assigned by the net to the the correct character: 0.026691533625125885\n", + "log likelihood: -3.623408794403076\n", + "negative log likelihood: 3.623408794403076\n", + "--------\n", + "bigram example 4: ma (indexes 13,1)\n", + "input to the neural net: 13\n", + "output probabilities from the neural net: tensor([0.0312, 0.0737, 0.0484, 0.0333, 0.0674, 0.0200, 0.0263, 0.0249, 0.1226,\n", + " 0.0164, 0.0075, 0.0789, 0.0131, 0.0267, 0.0147, 0.0112, 0.0585, 0.0121,\n", + " 0.0650, 0.0058, 0.0208, 0.0078, 0.0133, 0.0203, 0.1204, 0.0469, 0.0126])\n", + "label (actual next character): 1\n", + "probability assigned by the net to the the correct character: 0.07367684692144394\n", + "log likelihood: -2.6080667972564697\n", + "negative log likelihood: 2.6080667972564697\n", + "--------\n", + "bigram example 5: a. 
(indexes 1,0)\n", + "input to the neural net: 1\n", + "output probabilities from the neural net: tensor([0.0150, 0.0086, 0.0396, 0.0100, 0.0606, 0.0308, 0.1084, 0.0131, 0.0125,\n", + " 0.0048, 0.1024, 0.0086, 0.0988, 0.0112, 0.0232, 0.0207, 0.0408, 0.0078,\n", + " 0.0899, 0.0531, 0.0463, 0.0309, 0.0051, 0.0329, 0.0654, 0.0503, 0.0091])\n", + "label (actual next character): 0\n", + "probability assigned by the net to the the correct character: 0.0149775305762887\n", + "log likelihood: -4.201204299926758\n", + "negative log likelihood: 4.201204299926758\n", + "=========\n", + "average negative log likelihood, i.e. loss = 3.7693049907684326\n" + ] + } + ], + "source": [ + "\n", + "nlls = torch.zeros(5)\n", + "for i in range(5):\n", + " # i-th bigram:\n", + " x = xs[i].item() # input character index\n", + " y = ys[i].item() # label character index\n", + " print('--------')\n", + " print(f'bigram example {i+1}: {itos[x]}{itos[y]} (indexes {x},{y})')\n", + " print('input to the neural net:', x)\n", + " print('output probabilities from the neural net:', probs[i])\n", + " print('label (actual next character):', y)\n", + " p = probs[i, y]\n", + " print('probability assigned by the net to the the correct character:', p.item())\n", + " logp = torch.log(p)\n", + " print('log likelihood:', logp.item())\n", + " nll = -logp\n", + " print('negative log likelihood:', nll.item())\n", + " nlls[i] = nll\n", + "\n", + "print('=========')\n", + "print('average negative log likelihood, i.e. loss =', nlls.mean().item())" + ] + }, + { + "cell_type": "code", + "execution_count": 561, + "metadata": {}, + "outputs": [], + "source": [ + "# --------- !!! OPTIMIZATION !!! 
yay --------------" + ] + }, + { + "cell_type": "code", + "execution_count": 565, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 0, 5, 13, 13, 1])" + ] + }, + "execution_count": 565, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "xs" + ] + }, + { + "cell_type": "code", + "execution_count": 566, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 5, 13, 13, 1, 0])" + ] + }, + "execution_count": 566, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ys" + ] + }, + { + "cell_type": "code", + "execution_count": 580, + "metadata": {}, + "outputs": [], + "source": [ + "# randomly initialize 27 neurons' weights. each neuron receives 27 inputs\n", + "g = torch.Generator().manual_seed(2147483647)\n", + "W = torch.randn((27, 27), generator=g, requires_grad=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 602, + "metadata": {}, + "outputs": [], + "source": [ + "# forward pass\n", + "xenc = F.one_hot(xs, num_classes=27).float() # input to the network: one-hot encoding\n", + "logits = xenc @ W # predict log-counts\n", + "counts = logits.exp() # counts, equivalent to N\n", + "probs = counts / counts.sum(1, keepdims=True) # probabilities for next character\n", + "loss = -probs[torch.arange(5), ys].log().mean()" + ] + }, + { + "cell_type": "code", + "execution_count": 603, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3.6891887187957764\n" + ] + } + ], + "source": [ + "print(loss.item())" + ] + }, + { + "cell_type": "code", + "execution_count": 604, + "metadata": {}, + "outputs": [], + "source": [ + "# backward pass\n", + "W.grad = None # set to zero the gradient\n", + "loss.backward()" + ] + }, + { + "cell_type": "code", + "execution_count": 605, + "metadata": {}, + "outputs": [], + "source": [ + "W.data += -0.1 * W.grad" + ] + }, + { + "cell_type": "code", + "execution_count": 606, + "metadata": 
{}, + "outputs": [], + "source": [ + "# --------- !!! OPTIMIZATION !!! yay, but this time actually --------------" + ] + }, + { + "cell_type": "code", + "execution_count": 682, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "number of examples: 228146\n" + ] + } + ], + "source": [ + "# create the dataset\n", + "xs, ys = [], []\n", + "for w in words:\n", + " chs = ['.'] + list(w) + ['.']\n", + " for ch1, ch2 in zip(chs, chs[1:]):\n", + " ix1 = stoi[ch1]\n", + " ix2 = stoi[ch2]\n", + " xs.append(ix1)\n", + " ys.append(ix2)\n", + "xs = torch.tensor(xs)\n", + "ys = torch.tensor(ys)\n", + "num = xs.nelement()\n", + "print('number of examples: ', num)\n", + "\n", + "# initialize the 'network'\n", + "g = torch.Generator().manual_seed(2147483647)\n", + "W = torch.randn((27, 27), generator=g, requires_grad=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 716, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2.481828451156616\n" + ] + } + ], + "source": [ + "# gradient descent\n", + "for k in range(1):\n", + " \n", + " # forward pass\n", + " xenc = F.one_hot(xs, num_classes=27).float() # input to the network: one-hot encoding\n", + " logits = xenc @ W # predict log-counts\n", + " counts = logits.exp() # counts, equivalent to N\n", + " probs = counts / counts.sum(1, keepdims=True) # probabilities for next character\n", + " loss = -probs[torch.arange(num), ys].log().mean() + 0.01*(W**2).mean()\n", + " print(loss.item())\n", + " \n", + " # backward pass\n", + " W.grad = None # set to zero the gradient\n", + " loss.backward()\n", + " \n", + " # update\n", + " W.data += -50 * W.grad" + ] + }, + { + "cell_type": "code", + "execution_count": 725, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mor.\n", + "axx.\n", + "minaymoryles.\n", + "kondlaisah.\n", + "anchthizarie.\n" + ] + } + ], + "source": [ + "# finally, 
sample from the 'neural net' model\n", + "g = torch.Generator().manual_seed(2147483647)\n", + "\n", + "for i in range(5):\n", + " \n", + " out = []\n", + " ix = 0\n", + " while True:\n", + " \n", + " # ----------\n", + " # BEFORE:\n", + " #p = P[ix]\n", + " # ----------\n", + " # NOW:\n", + " xenc = F.one_hot(torch.tensor([ix]), num_classes=27).float()\n", + " logits = xenc @ W # predict log-counts\n", + " counts = logits.exp() # counts, equivalent to N\n", + " p = counts / counts.sum(1, keepdims=True) # probabilities for next character\n", + " # ----------\n", + " \n", + " ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()\n", + " out.append(itos[ix])\n", + " if ix == 0:\n", + " break\n", + " print(''.join(out))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/lectures/micrograd/micrograd_lecture_first_half_roughly.ipynb b/lectures/micrograd/micrograd_lecture_first_half_roughly.ipynb new file mode 100644 index 00000000..1a3fa629 --- /dev/null +++ b/lectures/micrograd/micrograd_lecture_first_half_roughly.ipynb @@ -0,0 +1,1278 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "def f(x):\n", + " return 3*x**2 - 4*x + 5" + ] + }, + { + "cell_type": "code", + 
"execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "20.0" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "f(3.0)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "xs = np.arange(-5, 5, 0.25)\n", + "ys = f(xs)\n", + "plt.plot(xs, ys)" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2.999378523327323e-06" + ] + }, + "execution_count": 42, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "h = 0.000001\n", + "x = 2/3\n", + "(f(x + h) - f(x))/h" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4.0\n" + ] + } + ], + "source": [ + "# les get more complex\n", + "a = 2.0\n", + "b = -3.0\n", + "c = 10.0\n", + "d = a*b + c\n", + "print(d)" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "d1 4.0\n", + "d2 4.0001\n", + "slope 0.9999999999976694\n" + ] + } + ], + "source": [ + "h = 0.0001\n", + "\n", + "# inputs\n", + "a = 2.0\n", + "b = -3.0\n", + "c = 10.0\n", + "\n", + "d1 = a*b + c\n", + "c += h\n", + "d2 = a*b + c\n", + "\n", + "print('d1', d1)\n", + "print('d2', d2)\n", + "print('slope', (d2 - d1)/h)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 257, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Value(data=-8.0)" + ] + }, + "execution_count": 257, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class Value:\n", + " \n", + " def __init__(self, data, _children=(), _op='', label=''):\n", + " self.data = data\n", + " self.grad = 0.0\n", + " self._backward = lambda: None\n", + " self._prev = set(_children)\n", + " self._op = _op\n", + " self.label = label\n", + "\n", + " def __repr__(self):\n", + " return f\"Value(data={self.data})\"\n", + " \n", + " def __add__(self, other):\n", + " out = Value(self.data + other.data, (self, 
other), '+')\n", + " \n", + " def _backward():\n", + " self.grad += 1.0 * out.grad\n", + " other.grad += 1.0 * out.grad\n", + " out._backward = _backward\n", + " \n", + " return out\n", + "\n", + " def __mul__(self, other):\n", + " out = Value(self.data * other.data, (self, other), '*')\n", + " \n", + " def _backward():\n", + " self.grad += other.data * out.grad\n", + " other.grad += self.data * out.grad\n", + " out._backward = _backward\n", + " \n", + " return out\n", + " \n", + " def tanh(self):\n", + " x = self.data\n", + " t = (math.exp(2*x) - 1)/(math.exp(2*x) + 1)\n", + " out = Value(t, (self, ), 'tanh')\n", + " \n", + " def _backward():\n", + " self.grad += (1 - t**2) * out.grad\n", + " out._backward = _backward\n", + " \n", + " return out\n", + " \n", + " def backward(self):\n", + " \n", + " topo = []\n", + " visited = set()\n", + " def build_topo(v):\n", + " if v not in visited:\n", + " visited.add(v)\n", + " for child in v._prev:\n", + " build_topo(child)\n", + " topo.append(v)\n", + " build_topo(self)\n", + " \n", + " self.grad = 1.0\n", + " for node in reversed(topo):\n", + " node._backward()\n", + "\n", + "\n", + "a = Value(2.0, label='a')\n", + "b = Value(-3.0, label='b')\n", + "c = Value(10.0, label='c')\n", + "e = a*b; e.label = 'e'\n", + "d = e + c; d.label = 'd'\n", + "f = Value(-2.0, label='f')\n", + "L = d * f; L.label = 'L'\n", + "L" + ] + }, + { + "cell_type": "code", + "execution_count": 139, + "metadata": {}, + "outputs": [], + "source": [ + "from graphviz import Digraph\n", + "\n", + "def trace(root):\n", + " # builds a set of all nodes and edges in a graph\n", + " nodes, edges = set(), set()\n", + " def build(v):\n", + " if v not in nodes:\n", + " nodes.add(v)\n", + " for child in v._prev:\n", + " edges.add((child, v))\n", + " build(child)\n", + " build(root)\n", + " return nodes, edges\n", + "\n", + "def draw_dot(root):\n", + " dot = Digraph(format='svg', graph_attr={'rankdir': 'LR'}) # LR = left to right\n", + " \n", + " nodes, edges = 
trace(root)\n", + " for n in nodes:\n", + " uid = str(id(n))\n", + " # for any value in the graph, create a rectangular ('record') node for it\n", + " dot.node(name = uid, label = \"{ %s | data %.4f | grad %.4f }\" % (n.label, n.data, n.grad), shape='record')\n", + " if n._op:\n", + " # if this value is a result of some operation, create an op node for it\n", + " dot.node(name = uid + n._op, label = n._op)\n", + " # and connect this node to it\n", + " dot.edge(uid + n._op, uid)\n", + "\n", + " for n1, n2 in edges:\n", + " # connect n1 to the op node of n2\n", + " dot.edge(str(id(n1)), str(id(n2)) + n2._op)\n", + "\n", + " return dot\n" + ] + }, + { + "cell_type": "code", + "execution_count": 144, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306649871024\n", + "\n", + "a\n", + "\n", + "data 2.0000\n", + "\n", + "grad 6.0000\n", + "\n", + "\n", + "\n", + "140306649871264*\n", + "\n", + "*\n", + "\n", + "\n", + "\n", + "140306649871024->140306649871264*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306649871072\n", + "\n", + "c\n", + "\n", + "data 10.0000\n", + "\n", + "grad -2.0000\n", + "\n", + "\n", + "\n", + "140306649871744+\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "140306649871072->140306649871744+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306649871120\n", + "\n", + "b\n", + "\n", + "data -3.0000\n", + "\n", + "grad -4.0000\n", + "\n", + "\n", + "\n", + "140306649871120->140306649871264*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306649871648\n", + "\n", + "L\n", + "\n", + "data -8.0000\n", + "\n", + "grad 1.0000\n", + "\n", + "\n", + "\n", + "140306649871648*\n", + "\n", + "*\n", + "\n", + "\n", + "\n", + "140306649871648*->140306649871648\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306649871744\n", + "\n", + "d\n", + "\n", + "data 4.0000\n", + "\n", + "grad -2.0000\n", + "\n", + "\n", + "\n", + 
"140306649871744->140306649871648*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306649871744+->140306649871744\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306649871264\n", + "\n", + "e\n", + "\n", + "data -6.0000\n", + "\n", + "grad -2.0000\n", + "\n", + "\n", + "\n", + "140306649871264->140306649871744+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306649871264*->140306649871264\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306649871792\n", + "\n", + "f\n", + "\n", + "data -2.0000\n", + "\n", + "grad 4.0000\n", + "\n", + "\n", + "\n", + "140306649871792->140306649871648*\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 144, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "draw_dot(L)" + ] + }, + { + "cell_type": "code", + "execution_count": 145, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "-7.286496\n" + ] + } + ], + "source": [ + "a.data += 0.01 * a.grad\n", + "b.data += 0.01 * b.grad\n", + "c.data += 0.01 * c.grad\n", + "f.data += 0.01 * f.grad\n", + "\n", + "e = a * b\n", + "d = e + c\n", + "L = d * f\n", + "\n", + "print(L.data)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 136, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "-3.9999999999995595\n" + ] + } + ], + "source": [ + "def lol():\n", + " \n", + " h = 0.001\n", + " \n", + " a = Value(2.0, label='a')\n", + " b = Value(-3.0, label='b')\n", + " c = Value(10.0, label='c')\n", + " e = a*b; e.label = 'e'\n", + " d = e + c; d.label = 'd'\n", + " f = Value(-2.0, label='f')\n", + " L = d * f; L.label = 'L'\n", + " L1 = L.data\n", + " \n", + " a = Value(2.0, label='a')\n", + " b = Value(-3.0, label='b')\n", + " b.data += h\n", + " c = Value(10.0, label='c')\n", + " e = a*b; e.label = 'e'\n", + " d = e + c; d.label = 'd'\n", + " f = Value(-2.0, label='f')\n", + " L = d * f; L.label = 'L'\n", + " L2 
= L.data\n", + " \n", + " print((L2 - L1)/h)\n", + " \n", + "lol()" + ] + }, + { + "cell_type": "code", + "execution_count": 152, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.plot(np.arange(-5,5,0.2), np.tanh(np.arange(-5,5,0.2))); plt.grid();" + ] + }, + { + "cell_type": "code", + "execution_count": 241, + "metadata": {}, + "outputs": [], + "source": [ + "# inputs x1,x2\n", + "x1 = Value(2.0, label='x1')\n", + "x2 = Value(0.0, label='x2')\n", + "# weights w1,w2\n", + "w1 = Value(-3.0, label='w1')\n", + "w2 = Value(1.0, label='w2')\n", + "# bias of the neuron\n", + "b = Value(6.8813735870195432, label='b')\n", + "# x1*w1 + x2*w2 + b\n", + "x1w1 = x1*w1; x1w1.label = 'x1*w1'\n", + "x2w2 = x2*w2; x2w2.label = 'x2*w2'\n", + "x1w1x2w2 = x1w1 + x2w2; x1w1x2w2.label = 'x1*w1 + x2*w2'\n", + "n = x1w1x2w2 + b; n.label = 'n'\n", + "o = n.tanh(); o.label = 'o'" + ] + }, + { + "cell_type": "code", + "execution_count": 244, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307056976896\n", + "\n", + "w2\n", + "\n", + "data 1.0000\n", + "\n", + "grad 0.0000\n", + "\n", + "\n", + "\n", + "140307056976608*\n", + "\n", + "*\n", + "\n", + "\n", + "\n", + "140307056976896->140307056976608*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307056979488\n", + "\n", + "x1*w1\n", + "\n", + "data -6.0000\n", + "\n", + "grad 0.5000\n", + "\n", + "\n", + "\n", + "140307056977616+\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "140307056979488->140307056977616+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307056979488*\n", + "\n", + "*\n", + "\n", + "\n", + "\n", + "140307056979488*->140307056979488\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307056975936\n", + "\n", + "x1\n", + "\n", + "data 2.0000\n", + "\n", + "grad -1.5000\n", + "\n", + "\n", + "\n", + "140307056975936->140307056979488*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307056975984\n", + "\n", + "x2\n", + "\n", + "data 0.0000\n", + "\n", + "grad 0.5000\n", + 
"\n", + "\n", + "\n", + "140307056975984->140307056976608*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307056976512\n", + "\n", + "o\n", + "\n", + "data 0.7071\n", + "\n", + "grad 1.0000\n", + "\n", + "\n", + "\n", + "140307056976512tanh\n", + "\n", + "tanh\n", + "\n", + "\n", + "\n", + "140307056976512tanh->140307056976512\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307056978576\n", + "\n", + "w1\n", + "\n", + "data -3.0000\n", + "\n", + "grad 1.0000\n", + "\n", + "\n", + "\n", + "140307056978576->140307056979488*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307056978624\n", + "\n", + "b\n", + "\n", + "data 6.8814\n", + "\n", + "grad 0.5000\n", + "\n", + "\n", + "\n", + "140307056976704+\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "140307056978624->140307056976704+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307056977616\n", + "\n", + "x1*w1 + x2*w2\n", + "\n", + "data -6.0000\n", + "\n", + "grad 0.5000\n", + "\n", + "\n", + "\n", + "140307056977616->140307056976704+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307056977616+->140307056977616\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307056976608\n", + "\n", + "x2*w2\n", + "\n", + "data 0.0000\n", + "\n", + "grad 0.5000\n", + "\n", + "\n", + "\n", + "140307056976608->140307056977616+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307056976608*->140307056976608\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307056976704\n", + "\n", + "n\n", + "\n", + "data 0.8814\n", + "\n", + "grad 0.5000\n", + "\n", + "\n", + "\n", + "140307056976704->140307056976512tanh\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307056976704+->140307056976704\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 244, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "draw_dot(o)" + ] + }, + { + "cell_type": "code", + "execution_count": 243, + "metadata": {}, + "outputs": [], + "source": [ + "o.backward()" + ] + }, + { + 
"cell_type": "code", + "execution_count": 235, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Value(data=6.881373587019543),\n", + " Value(data=2.0),\n", + " Value(data=-3.0),\n", + " Value(data=-6.0),\n", + " Value(data=0.0),\n", + " Value(data=1.0),\n", + " Value(data=0.0),\n", + " Value(data=-6.0),\n", + " Value(data=0.8813735870195432),\n", + " Value(data=0.7071067811865476)]" + ] + }, + "execution_count": 235, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "topo = []\n", + "visited = set()\n", + "def build_topo(v):\n", + " if v not in visited:\n", + " visited.add(v)\n", + " for child in v._prev:\n", + " build_topo(child)\n", + " topo.append(v)\n", + "build_topo(o)\n", + "topo" + ] + }, + { + "cell_type": "code", + "execution_count": 221, + "metadata": {}, + "outputs": [], + "source": [ + "o.grad = 1.0" + ] + }, + { + "cell_type": "code", + "execution_count": 223, + "metadata": {}, + "outputs": [], + "source": [ + "o._backward()" + ] + }, + { + "cell_type": "code", + "execution_count": 225, + "metadata": {}, + "outputs": [], + "source": [ + "n._backward()" + ] + }, + { + "cell_type": "code", + "execution_count": 227, + "metadata": {}, + "outputs": [], + "source": [ + "b._backward()" + ] + }, + { + "cell_type": "code", + "execution_count": 228, + "metadata": {}, + "outputs": [], + "source": [ + "x1w1x2w2._backward()" + ] + }, + { + "cell_type": "code", + "execution_count": 230, + "metadata": {}, + "outputs": [], + "source": [ + "x2w2._backward()\n", + "x1w1._backward()" + ] + }, + { + "cell_type": "code", + "execution_count": 200, + "metadata": {}, + "outputs": [], + "source": [ + "x1.grad = w1.data * x1w1.grad\n", + "w1.grad = x1.data * x1w1.grad" + ] + }, + { + "cell_type": "code", + "execution_count": 198, + "metadata": {}, + "outputs": [], + "source": [ + "x2.grad = w2.data * x2w2.grad\n", + "w2.grad = x2.data * x2w2.grad" + ] + }, + { + "cell_type": "code", + "execution_count": 196, + "metadata": {}, + 
"outputs": [], + "source": [ + "x1w1.grad = 0.5\n", + "x2w2.grad = 0.5" + ] + }, + { + "cell_type": "code", + "execution_count": 194, + "metadata": {}, + "outputs": [], + "source": [ + "x1w1x2w2.grad = 0.5\n", + "b.grad = 0.5" + ] + }, + { + "cell_type": "code", + "execution_count": 192, + "metadata": {}, + "outputs": [], + "source": [ + "n.grad = 0.5" + ] + }, + { + "cell_type": "code", + "execution_count": 187, + "metadata": {}, + "outputs": [], + "source": [ + "o.grad = 1.0" + ] + }, + { + "cell_type": "code", + "execution_count": 191, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.4999999999999999" + ] + }, + "execution_count": 191, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "1 - o.data**2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# o = tanh(n)\n", + "# do/dn = 1 - o**2" + ] + }, + { + "cell_type": "code", + "execution_count": 258, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307056784240\n", + "\n", + "a\n", + "\n", + "data 3.0000\n", + "\n", + "grad 2.0000\n", + "\n", + "\n", + "\n", + "140307056785008+\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "140307056784240->140307056785008+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307056785008\n", + "\n", + "b\n", + "\n", + "data 6.0000\n", + "\n", + "grad 1.0000\n", + "\n", + "\n", + "\n", + "140307056785008+->140307056785008\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 258, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "a = Value(3.0, label='a')\n", + "b = a + a ; b.label = 'b'\n", + "b.backward()\n", + "draw_dot(b)" + ] + }, + { + "cell_type": "code", + "execution_count": 259, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + 
"\n", + "\n", + "\n", + "\n", + "140306525506048\n", + "\n", + "a\n", + "\n", + "data -2.0000\n", + "\n", + "grad -3.0000\n", + "\n", + "\n", + "\n", + "140307056785968+\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "140306525506048->140307056785968+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306525506912*\n", + "\n", + "*\n", + "\n", + "\n", + "\n", + "140306525506048->140306525506912*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307056785968\n", + "\n", + "e\n", + "\n", + "data 1.0000\n", + "\n", + "grad -6.0000\n", + "\n", + "\n", + "\n", + "140307056783712*\n", + "\n", + "*\n", + "\n", + "\n", + "\n", + "140307056785968->140307056783712*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307056785968+->140307056785968\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306525506912\n", + "\n", + "d\n", + "\n", + "data -6.0000\n", + "\n", + "grad 1.0000\n", + "\n", + "\n", + "\n", + "140306525506912->140307056783712*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306525506912*->140306525506912\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307056783712\n", + "\n", + "f\n", + "\n", + "data -6.0000\n", + "\n", + "grad 1.0000\n", + "\n", + "\n", + "\n", + "140307056783712*->140307056783712\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306525506432\n", + "\n", + "b\n", + "\n", + "data 3.0000\n", + "\n", + "grad -8.0000\n", + "\n", + "\n", + "\n", + "140306525506432->140307056785968+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306525506432->140306525506912*\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 259, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "a = Value(-2.0, label='a')\n", + "b = Value(3.0, label='b')\n", + "d = a * b ; d.label = 'd'\n", + "e = a + b ; e.label = 'e'\n", + "f = d * e ; f.label = 'f'\n", + "\n", + "f.backward()\n", + "\n", + "draw_dot(f)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": 
[], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/lectures/micrograd/micrograd_lecture_second_half_roughly.ipynb b/lectures/micrograd/micrograd_lecture_second_half_roughly.ipynb new file mode 100644 index 00000000..67a25a68 --- /dev/null +++ b/lectures/micrograd/micrograd_lecture_second_half_roughly.ipynb @@ -0,0 +1,1273 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 382, + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "import random\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "def f(x):\n", + " return 3*x**2 - 4*x + 5" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "20.0" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "f(3.0)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "xs = np.arange(-5, 5, 0.25)\n", + "ys = f(xs)\n", + "plt.plot(xs, ys)" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2.999378523327323e-06" + ] + }, + "execution_count": 42, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "h = 0.000001\n", + "x = 2/3\n", + "(f(x + h) - f(x))/h" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4.0\n" + ] + } + ], + "source": [ + "# les get more complex\n", + "a = 2.0\n", + "b = -3.0\n", + "c = 10.0\n", + "d = a*b + c\n", + "print(d)" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "d1 4.0\n", + "d2 4.0001\n", + "slope 0.9999999999976694\n" + ] + } + ], + "source": [ + "h = 0.0001\n", + "\n", + "# inputs\n", + "a = 2.0\n", + "b = -3.0\n", + "c = 10.0\n", + "\n", + "d1 = a*b + c\n", + "c += h\n", + "d2 = a*b + c\n", + "\n", + "print('d1', d1)\n", + "print('d2', d2)\n", + "print('slope', (d2 - d1)/h)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 328, + "metadata": {}, + "outputs": [], + "source": [ + "class Value:\n", + " \n", + " def __init__(self, data, _children=(), _op='', label=''):\n", + " self.data = data\n", + " self.grad = 0.0\n", + " self._backward = lambda: None\n", + " self._prev = set(_children)\n", + " self._op = _op\n", + " self.label = label\n", + "\n", + " def __repr__(self):\n", + " return f\"Value(data={self.data})\"\n", + " \n", + " def __add__(self, other):\n", + " other = other if isinstance(other, Value) else Value(other)\n", + " out = Value(self.data + other.data, (self, other), '+')\n", + " \n", + " def _backward():\n", + " self.grad += 1.0 * 
out.grad\n", + " other.grad += 1.0 * out.grad\n", + " out._backward = _backward\n", + " \n", + " return out\n", + "\n", + " def __mul__(self, other):\n", + " other = other if isinstance(other, Value) else Value(other)\n", + " out = Value(self.data * other.data, (self, other), '*')\n", + " \n", + " def _backward():\n", + " self.grad += other.data * out.grad\n", + " other.grad += self.data * out.grad\n", + " out._backward = _backward\n", + " \n", + " return out\n", + " \n", + " def __pow__(self, other):\n", + " assert isinstance(other, (int, float)), \"only supporting int/float powers for now\"\n", + " out = Value(self.data**other, (self,), f'**{other}')\n", + "\n", + " def _backward():\n", + " self.grad += other * (self.data ** (other - 1)) * out.grad\n", + " out._backward = _backward\n", + "\n", + " return out\n", + " \n", + " def __rmul__(self, other): # other * self\n", + " return self * other\n", + "\n", + " def __truediv__(self, other): # self / other\n", + " return self * other**-1\n", + "\n", + " def __neg__(self): # -self\n", + " return self * -1\n", + "\n", + " def __sub__(self, other): # self - other\n", + " return self + (-other)\n", + "\n", + " def __radd__(self, other): # other + self\n", + " return self + other\n", + "\n", + " def tanh(self):\n", + " x = self.data\n", + " t = (math.exp(2*x) - 1)/(math.exp(2*x) + 1)\n", + " out = Value(t, (self, ), 'tanh')\n", + " \n", + " def _backward():\n", + " self.grad += (1 - t**2) * out.grad\n", + " out._backward = _backward\n", + " \n", + " return out\n", + " \n", + " def exp(self):\n", + " x = self.data\n", + " out = Value(math.exp(x), (self, ), 'exp')\n", + " \n", + " def _backward():\n", + " self.grad += out.data * out.grad # NOTE: in the video I incorrectly used = instead of +=. 
Fixed here.\n", + " out._backward = _backward\n", + " \n", + " return out\n", + " \n", + " \n", + " def backward(self):\n", + " \n", + " topo = []\n", + " visited = set()\n", + " def build_topo(v):\n", + " if v not in visited:\n", + " visited.add(v)\n", + " for child in v._prev:\n", + " build_topo(child)\n", + " topo.append(v)\n", + " build_topo(self)\n", + " \n", + " self.grad = 1.0\n", + " for node in reversed(topo):\n", + " node._backward()\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 139, + "metadata": {}, + "outputs": [], + "source": [ + "from graphviz import Digraph\n", + "\n", + "def trace(root):\n", + " # builds a set of all nodes and edges in a graph\n", + " nodes, edges = set(), set()\n", + " def build(v):\n", + " if v not in nodes:\n", + " nodes.add(v)\n", + " for child in v._prev:\n", + " edges.add((child, v))\n", + " build(child)\n", + " build(root)\n", + " return nodes, edges\n", + "\n", + "def draw_dot(root):\n", + " dot = Digraph(format='svg', graph_attr={'rankdir': 'LR'}) # LR = left to right\n", + " \n", + " nodes, edges = trace(root)\n", + " for n in nodes:\n", + " uid = str(id(n))\n", + " # for any value in the graph, create a rectangular ('record') node for it\n", + " dot.node(name = uid, label = \"{ %s | data %.4f | grad %.4f }\" % (n.label, n.data, n.grad), shape='record')\n", + " if n._op:\n", + " # if this value is a result of some operation, create an op node for it\n", + " dot.node(name = uid + n._op, label = n._op)\n", + " # and connect this node to it\n", + " dot.edge(uid + n._op, uid)\n", + "\n", + " for n1, n2 in edges:\n", + " # connect n1 to the op node of n2\n", + " dot.edge(str(id(n1)), str(id(n2)) + n2._op)\n", + "\n", + " return dot\n" + ] + }, + { + "cell_type": "code", + "execution_count": 318, + "metadata": {}, + "outputs": [], + "source": [ + "# inputs x1,x2\n", + "x1 = Value(2.0, label='x1')\n", + "x2 = Value(0.0, label='x2')\n", + "# weights w1,w2\n", + "w1 = Value(-3.0, label='w1')\n", + "w2 = 
Value(1.0, label='w2')\n", + "# bias of the neuron\n", + "b = Value(6.8813735870195432, label='b')\n", + "# x1*w1 + x2*w2 + b\n", + "x1w1 = x1*w1; x1w1.label = 'x1*w1'\n", + "x2w2 = x2*w2; x2w2.label = 'x2*w2'\n", + "x1w1x2w2 = x1w1 + x2w2; x1w1x2w2.label = 'x1*w1 + x2*w2'\n", + "n = x1w1x2w2 + b; n.label = 'n'\n", + "o = n.tanh(); o.label = 'o'\n", + "o.backward()" + ] + }, + { + "cell_type": "code", + "execution_count": 319, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307325553680\n", + "\n", + "x1*w1\n", + "\n", + "data -6.0000\n", + "\n", + "grad 0.5000\n", + "\n", + "\n", + "\n", + "140307325551664+\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "140307325553680->140307325551664+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307325553680*\n", + "\n", + "*\n", + "\n", + "\n", + "\n", + "140307325553680*->140307325553680\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306250485280\n", + "\n", + "w1\n", + "\n", + "data -3.0000\n", + "\n", + "grad 1.0000\n", + "\n", + "\n", + "\n", + "140306250485280->140307325553680*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307325551664\n", + "\n", + "x1*w1 + x2*w2\n", + "\n", + "data -6.0000\n", + "\n", + "grad 0.5000\n", + "\n", + "\n", + "\n", + "140307325551424+\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "140307325551664->140307325551424+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307325551664+->140307325551664\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307325553728\n", + "\n", + "o\n", + "\n", + "data 0.7071\n", + "\n", + "grad 1.0000\n", + "\n", + "\n", + "\n", + "140307325553728tanh\n", + "\n", + "tanh\n", + "\n", + "\n", + "\n", + "140307325553728tanh->140307325553728\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306786257456\n", + "\n", + "w2\n", + "\n", + "data 1.0000\n", + "\n", + "grad 0.0000\n", + "\n", + "\n", + "\n", + "140307325550800*\n", + "\n", + "*\n", + "\n", + "\n", + 
"\n", + "140306786257456->140307325550800*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307325550800\n", + "\n", + "x2*w2\n", + "\n", + "data 0.0000\n", + "\n", + "grad 0.5000\n", + "\n", + "\n", + "\n", + "140307325550800->140307325551664+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307325550800*->140307325550800\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306250483984\n", + "\n", + "b\n", + "\n", + "data 6.8814\n", + "\n", + "grad 0.5000\n", + "\n", + "\n", + "\n", + "140306250483984->140307325551424+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306786256688\n", + "\n", + "x1\n", + "\n", + "data 2.0000\n", + "\n", + "grad -1.5000\n", + "\n", + "\n", + "\n", + "140306786256688->140307325553680*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307325551424\n", + "\n", + "n\n", + "\n", + "data 0.8814\n", + "\n", + "grad 0.5000\n", + "\n", + "\n", + "\n", + "140307325551424->140307325553728tanh\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307325551424+->140307325551424\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306786257264\n", + "\n", + "x2\n", + "\n", + "data 0.0000\n", + "\n", + "grad 0.5000\n", + "\n", + "\n", + "\n", + "140306786257264->140307325550800*\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 319, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "draw_dot(o)" + ] + }, + { + "cell_type": "code", + "execution_count": 320, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306525507104\n", + "\n", + " \n", + "\n", + "data 4.8284\n", + "\n", + "grad 0.1464\n", + "\n", + "\n", + "\n", + "140307325465216*\n", + "\n", + "*\n", + "\n", + "\n", + "\n", + "140306525507104->140307325465216*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306525507104+\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "140306525507104+->140306525507104\n", + "\n", + "\n", + 
"\n", + "\n", + "\n", + "140307325464640\n", + "\n", + " \n", + "\n", + "data 0.1464\n", + "\n", + "grad 4.8284\n", + "\n", + "\n", + "\n", + "140307325464640->140307325465216*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307325464640**-1\n", + "\n", + "**-1\n", + "\n", + "\n", + "\n", + "140307325464640**-1->140307325464640\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307057181792\n", + "\n", + "x2*w2\n", + "\n", + "data 0.0000\n", + "\n", + "grad 0.5000\n", + "\n", + "\n", + "\n", + "140307057184240+\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "140307057181792->140307057184240+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307057181792*\n", + "\n", + "*\n", + "\n", + "\n", + "\n", + "140307057181792*->140307057181792\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307325465216\n", + "\n", + "o\n", + "\n", + "data 0.7071\n", + "\n", + "grad 1.0000\n", + "\n", + "\n", + "\n", + "140307325465216*->140307325465216\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306786144896\n", + "\n", + "b\n", + "\n", + "data 6.8814\n", + "\n", + "grad 0.5000\n", + "\n", + "\n", + "\n", + "140307057183664+\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "140306786144896->140307057183664+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307325467808\n", + "\n", + " \n", + "\n", + "data 1.0000\n", + "\n", + "grad -0.1036\n", + "\n", + "\n", + "\n", + "140307325466992+\n", + "\n", + "+\n", + "\n", + "\n", + "\n", + "140307325467808->140307325466992+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306786145184\n", + "\n", + "x1\n", + "\n", + "data 2.0000\n", + "\n", + "grad -1.5000\n", + "\n", + "\n", + "\n", + "140306786146720*\n", + "\n", + "*\n", + "\n", + "\n", + "\n", + "140306786145184->140306786146720*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306786147536\n", + "\n", + "w1\n", + "\n", + "data -3.0000\n", + "\n", + "grad 1.0000\n", + "\n", + "\n", + "\n", + "140306786147536->140306786146720*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + 
"140306525528800\n", + "\n", + " \n", + "\n", + "data 5.8284\n", + "\n", + "grad 0.0429\n", + "\n", + "\n", + "\n", + "140306525528800->140306525507104+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306525528800->140307325466992+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306525528800exp\n", + "\n", + "exp\n", + "\n", + "\n", + "\n", + "140306525528800exp->140306525528800\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306786146720\n", + "\n", + "x1*w1\n", + "\n", + "data -6.0000\n", + "\n", + "grad 0.5000\n", + "\n", + "\n", + "\n", + "140306786146720->140307057184240+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306786146720*->140306786146720\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307057180928\n", + "\n", + " \n", + "\n", + "data 2.0000\n", + "\n", + "grad 0.2203\n", + "\n", + "\n", + "\n", + "140307057184096*\n", + "\n", + "*\n", + "\n", + "\n", + "\n", + "140307057180928->140307057184096*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306786145712\n", + "\n", + "x2\n", + "\n", + "data 0.0000\n", + "\n", + "grad 0.5000\n", + "\n", + "\n", + "\n", + "140306786145712->140307057181792*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307057184096\n", + "\n", + " \n", + "\n", + "data 1.7627\n", + "\n", + "grad 0.2500\n", + "\n", + "\n", + "\n", + "140307057184096->140306525528800exp\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307057184096*->140307057184096\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307325466992\n", + "\n", + " \n", + "\n", + "data 6.8284\n", + "\n", + "grad -0.1036\n", + "\n", + "\n", + "\n", + "140307325466992->140307325464640**-1\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307325466992+->140307325466992\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306525505952\n", + "\n", + " \n", + "\n", + "data -1.0000\n", + "\n", + "grad 0.1464\n", + "\n", + "\n", + "\n", + "140306525505952->140306525507104+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307057183664\n", + "\n", + "n\n", + "\n", + 
"data 0.8814\n", + "\n", + "grad 0.5000\n", + "\n", + "\n", + "\n", + "140307057183664->140307057184096*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307057183664+->140307057183664\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140306786145232\n", + "\n", + "w2\n", + "\n", + "data 1.0000\n", + "\n", + "grad 0.0000\n", + "\n", + "\n", + "\n", + "140306786145232->140307057181792*\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307057184240\n", + "\n", + "x1*w1 + x2*w2\n", + "\n", + "data -6.0000\n", + "\n", + "grad 0.5000\n", + "\n", + "\n", + "\n", + "140307057184240->140307057183664+\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "140307057184240+->140307057184240\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 320, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# inputs x1,x2\n", + "x1 = Value(2.0, label='x1')\n", + "x2 = Value(0.0, label='x2')\n", + "# weights w1,w2\n", + "w1 = Value(-3.0, label='w1')\n", + "w2 = Value(1.0, label='w2')\n", + "# bias of the neuron\n", + "b = Value(6.8813735870195432, label='b')\n", + "# x1*w1 + x2*w2 + b\n", + "x1w1 = x1*w1; x1w1.label = 'x1*w1'\n", + "x2w2 = x2*w2; x2w2.label = 'x2*w2'\n", + "x1w1x2w2 = x1w1 + x2w2; x1w1x2w2.label = 'x1*w1 + x2*w2'\n", + "n = x1w1x2w2 + b; n.label = 'n'\n", + "# ----\n", + "e = (2*n).exp()\n", + "o = (e - 1) / (e + 1)\n", + "# ----\n", + "o.label = 'o'\n", + "o.backward()\n", + "draw_dot(o)" + ] + }, + { + "cell_type": "code", + "execution_count": 369, + "metadata": {}, + "outputs": [], + "source": [ + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 376, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.7071066904050358\n", + "---\n", + "x2 0.5000001283844369\n", + "w2 0.0\n", + "x1 -1.5000003851533106\n", + "w1 1.0000002567688737\n" + ] + } + ], + "source": [ + "\n", + "x1 = torch.Tensor([2.0]).double() ; x1.requires_grad = True\n", 
+ "x2 = torch.Tensor([0.0]).double() ; x2.requires_grad = True\n", + "w1 = torch.Tensor([-3.0]).double() ; w1.requires_grad = True\n", + "w2 = torch.Tensor([1.0]).double() ; w2.requires_grad = True\n", + "b = torch.Tensor([6.8813735870195432]).double() ; b.requires_grad = True\n", + "n = x1*w1 + x2*w2 + b\n", + "o = torch.tanh(n)\n", + "\n", + "print(o.data.item())\n", + "o.backward()\n", + "\n", + "print('---')\n", + "print('x2', x2.grad.item())\n", + "print('w2', w2.grad.item())\n", + "print('x1', x1.grad.item())\n", + "print('w1', w1.grad.item())" + ] + }, + { + "cell_type": "code", + "execution_count": 592, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "class Neuron:\n", + " \n", + " def __init__(self, nin):\n", + " self.w = [Value(random.uniform(-1,1)) for _ in range(nin)]\n", + " self.b = Value(random.uniform(-1,1))\n", + " \n", + " def __call__(self, x):\n", + " # w * x + b\n", + " act = sum((wi*xi for wi, xi in zip(self.w, x)), self.b)\n", + " out = act.tanh()\n", + " return out\n", + " \n", + " def parameters(self):\n", + " return self.w + [self.b]\n", + "\n", + "class Layer:\n", + " \n", + " def __init__(self, nin, nout):\n", + " self.neurons = [Neuron(nin) for _ in range(nout)]\n", + " \n", + " def __call__(self, x):\n", + " outs = [n(x) for n in self.neurons]\n", + " return outs[0] if len(outs) == 1 else outs\n", + " \n", + " def parameters(self):\n", + " return [p for neuron in self.neurons for p in neuron.parameters()]\n", + "\n", + "class MLP:\n", + " \n", + " def __init__(self, nin, nouts):\n", + " sz = [nin] + nouts\n", + " self.layers = [Layer(sz[i], sz[i+1]) for i in range(len(nouts))]\n", + " \n", + " def __call__(self, x):\n", + " for layer in self.layers:\n", + " x = layer(x)\n", + " return x\n", + " \n", + " def parameters(self):\n", + " return [p for layer in self.layers for p in layer.parameters()]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 665, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ 
+ "Value(data=0.16578526021381612)" + ] + }, + "execution_count": 665, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x = [2.0, 3.0, -1.0]\n", + "n = MLP(3, [4, 4, 1])\n", + "n(x)" + ] + }, + { + "cell_type": "code", + "execution_count": 666, + "metadata": {}, + "outputs": [], + "source": [ + "xs = [\n", + " [2.0, 3.0, -1.0],\n", + " [3.0, -1.0, 0.5],\n", + " [0.5, 1.0, 1.0],\n", + " [1.0, 1.0, -1.0],\n", + "]\n", + "ys = [1.0, -1.0, -1.0, 1.0] # desired targets" + ] + }, + { + "cell_type": "code", + "execution_count": 682, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0 0.002056123958292787\n", + "1 0.0020404768419831024\n", + "2 0.0020250564320649566\n", + "3 0.002009857924015696\n", + "4 0.001994876646733686\n", + "5 0.0019801080579702683\n", + "6 0.001965547739947282\n", + "7 0.0019511913951512907\n", + "8 0.0019370348422964524\n", + "9 0.0019230740124479978\n", + "10 0.001909304945299319\n", + "11 0.0018957237855951486\n", + "12 0.0018823267796946328\n", + "13 0.0018691102722676993\n", + "14 0.0018560707031189828\n", + "15 0.0018432046041333716\n", + "16 0.0018305085963379896\n", + "17 0.0018179793870754363\n", + "18 0.0018056137672833098\n", + "19 0.0017934086088756394\n" + ] + } + ], + "source": [ + "\n", + "for k in range(20):\n", + " \n", + " # forward pass\n", + " ypred = [n(x) for x in xs]\n", + " loss = sum((yout - ygt)**2 for ygt, yout in zip(ys, ypred))\n", + " \n", + " # backward pass\n", + " for p in n.parameters():\n", + " p.grad = 0.0\n", + " loss.backward()\n", + " \n", + " # update\n", + " for p in n.parameters():\n", + " p.data += -0.1 * p.grad\n", + " \n", + " print(k, loss.data)\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 683, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Value(data=0.9817830812439714),\n", + " Value(data=-0.9863881624765284),\n", + " Value(data=-0.9766534529377958),\n", + " 
Value(data=0.9729591216966093)]" + ] + }, + "execution_count": 683, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ypred" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}