forked from jupytercalpoly/plyto
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
769 additions
and
399 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,327 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Introduction to Plyto with PyTorch\n", | ||
"#### Python Machine Learning Visualization Toolkit\n", | ||
"This notebook will demonstrate how to use our example PyTorch loss callback function with Plyto to visualize model loss throughout the training process of a machine learning algorithm, as well as a tutorial on how to create your own callback function" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"<img src='style/icons/machinelearning-blue.svg'> \n", | ||
"toolbar item opens the Plyto model visualizer for this notebook!" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### How it works" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"A callback function that takes a or uses Plyto instance as a parameter is called each iteration through the training loop when using PyTorch. A Plyto instance requires an Altair spec to define plots. Below is an example of a simple altair spec to create a line graph of samples versus loss." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from plyto import PlytoAPI\n", | ||
"\n", | ||
"# an array of Altair specs with one plot of samples versus loss\n", | ||
"spec = [\n", | ||
" {\n", | ||
" # specifies an altair spec\n", | ||
" \"$schema\": \"https://vega.github.io/schema/vega-lite/v2.json\",\n", | ||
" \"name\": \"lossGraph\",\n", | ||
" \n", | ||
" #size of the plot\n", | ||
" \"config\": {\n", | ||
" \"view\": {\n", | ||
" \"height\": 300,\n", | ||
" \"width\": 300\n", | ||
" }\n", | ||
" },\n", | ||
" \n", | ||
" # name of the dataset must be \"dataSet\"\n", | ||
" \"data\": {\n", | ||
" \"name\": \"dataSet\"\n", | ||
" },\n", | ||
" \n", | ||
" # visual encodings of the plot\n", | ||
" \"encoding\": {\n", | ||
" \"x\": {\n", | ||
" \"field\": \"samples\",\n", | ||
" \"type\": \"quantitative\"\n", | ||
" },\n", | ||
" \"y\": {\n", | ||
" \"field\": \"loss\",\n", | ||
" \"type\": \"quantitative\"\n", | ||
" }\n", | ||
" },\n", | ||
" \n", | ||
" \"mark\": \"line\"\n", | ||
" }\n", | ||
"]\n", | ||
"\n", | ||
"# pass the spec into Plyto\n", | ||
"plyto_instance = PlytoAPI(spec)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### Running a model\n", | ||
"To demonstrate how Plyto works, we will be looking at the CIFAR-10 tiny image data, which can be loaded from torchvision.datasets" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import torch\n", | ||
"import torchvision\n", | ||
"import torchvision.transforms as transforms\n", | ||
"import torch.nn as nn\n", | ||
"import torch.nn.functional as F\n", | ||
"import torch.optim as optim\n", | ||
"\n", | ||
"\n", | ||
"transform = transforms.Compose(\n", | ||
" [transforms.ToTensor(),\n", | ||
" transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n", | ||
"\n", | ||
"trainset = torchvision.datasets.CIFAR10(root='./data', train=True,\n", | ||
" download=True, transform=transform)\n", | ||
"trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,\n", | ||
" shuffle=True, num_workers=2)\n", | ||
"\n", | ||
"testset = torchvision.datasets.CIFAR10(root='./data', train=False,\n", | ||
" download=True, transform=transform)\n", | ||
"testloader = torch.utils.data.DataLoader(testset, batch_size=4,\n", | ||
" shuffle=False, num_workers=2)\n", | ||
"\n", | ||
"classes = ('plane', 'car', 'bird', 'cat',\n", | ||
" 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n", | ||
"\n", | ||
"class Net(nn.Module):\n", | ||
" def __init__(self):\n", | ||
" super(Net, self).__init__()\n", | ||
" self.conv1 = nn.Conv2d(3, 6, 5)\n", | ||
" self.pool = nn.MaxPool2d(2, 2)\n", | ||
" self.conv2 = nn.Conv2d(6, 16, 5)\n", | ||
" self.fc1 = nn.Linear(16 * 5 * 5, 120)\n", | ||
" self.fc2 = nn.Linear(120, 84)\n", | ||
" self.fc3 = nn.Linear(84, 10)\n", | ||
"\n", | ||
" def forward(self, x):\n", | ||
" x = self.pool(F.relu(self.conv1(x)))\n", | ||
" x = self.pool(F.relu(self.conv2(x)))\n", | ||
" x = x.view(-1, 16 * 5 * 5)\n", | ||
" x = F.relu(self.fc1(x))\n", | ||
" x = F.relu(self.fc2(x))\n", | ||
" x = self.fc3(x)\n", | ||
" return x\n", | ||
"\n", | ||
"\n", | ||
"net = Net()\n", | ||
"criterion = nn.CrossEntropyLoss()\n", | ||
"optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"However you structure your network, simply call the callback function every N iterations through the training loop you want to update the data and open the Plyto model visualizer to see your statistics and plots update." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": false | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from time import time\n", | ||
"from plyto import PytorchLossCallback\n", | ||
"\n", | ||
"callback = PytorchLossCallback(plyto_instance, 20, 12400) # 5 epochs of \n", | ||
" # 12400 mini-batches each\n", | ||
"\n", | ||
"for epoch in range(20): # loop over the dataset multiple times\n", | ||
" callback.update_step_number(epoch + 1) # update the current epoch\n", | ||
"\n", | ||
" running_loss = 0.0\n", | ||
" for i, data in enumerate(trainloader, 0):\n", | ||
" inputs, labels = data\n", | ||
"\n", | ||
" optimizer.zero_grad()\n", | ||
"\n", | ||
" outputs = net(inputs)\n", | ||
" loss = criterion(outputs, labels)\n", | ||
" loss.backward()\n", | ||
" optimizer.step()\n", | ||
"\n", | ||
" running_loss += loss.item()\n", | ||
" if i % 100 == 0 and i != 0: # print every 100 mini-batches\n", | ||
" callback.update_total_progress(100) # update total progress\n", | ||
" callback.update_data(i, running_loss / 100) # update current progress,\n", | ||
" # loss, and send data\n", | ||
" running_loss = 0.0" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### Writing your own callback function" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"A callback function for Plyto is a class that takes an instance of PlytoAPI as a parameter. \n", | ||
"\n", | ||
"Within this custom function, you can define functions to execute or update data at specific points in running the network.\n", | ||
"\n", | ||
"For the progress bars in the status bar to work correctly, your callback function must send epochs, sample_amount, total_progress, current_progress, and epoch_number using Plyto. Further, start_time is required for the panel to display the runtime once the model is complete. Below is a base to work off of, only containing these variables for basic functionality and passing no altair spec for plots." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"class PytorchBasicCallback:\n", | ||
" \"\"\"\n", | ||
" Create a callback that will track and display training progress\n", | ||
"\n", | ||
" :param steps: number of epochs/steps\n", | ||
"\n", | ||
" :param sample_amount: number of samples/steps per epoch\n", | ||
"\n", | ||
" :param start_time: start of training time, used to calculate runtime\n", | ||
"\n", | ||
" :param plyto: an instance of a PlytoAPI class\n", | ||
" \"\"\"\n", | ||
"\n", | ||
" def __init__(self, plyto_instance, steps=0, sample_amount=0):\n", | ||
" self.total_progress = 0\n", | ||
" self.start_time = time()\n", | ||
" self.plyto = plyto_instance\n", | ||
" self.initalize_plyto(steps, sample_amount)\n", | ||
"\n", | ||
" def initalize_plyto(self, steps, sample_amount):\n", | ||
" \"\"\"\n", | ||
" Initalize the Plyto instance's total steps and step size\n", | ||
" \n", | ||
" :param steps: total number of steps\n", | ||
"\n", | ||
" :param sample_amount: number of samples/batches per step\n", | ||
" \"\"\"\n", | ||
" self.plyto.update_total_steps(steps)\n", | ||
" self.plyto.update_size(sample_amount)\n", | ||
"\n", | ||
" def update_step_number(self, new_step):\n", | ||
" \"\"\"\n", | ||
" Update the current step/epoch\n", | ||
"\n", | ||
" :param new_step: the current step/epoch\n", | ||
" \"\"\"\n", | ||
" self.plyto.update_current_step(new_step)\n", | ||
"\n", | ||
" def update_total_progress(self, progress):\n", | ||
" \"\"\"\n", | ||
" Update the total training progress\n", | ||
"\n", | ||
" :param progress: the amount to increment the total progress by\n", | ||
" \"\"\"\n", | ||
" self.total_progress += progress\n", | ||
"\n", | ||
" def update_data(self, current_progress):\n", | ||
" \"\"\"\n", | ||
" Update progress, total progress, loss, and runtime before sending data to frontend\n", | ||
"\n", | ||
" :param current_progress: the progress of training the current step/epoch\n", | ||
" \n", | ||
" :param loss: the current batch's training loss\n", | ||
" \"\"\"\n", | ||
" self.plyto.update_current_progress(current_progress)\n", | ||
" self.plyto.update_total_progress(self.total_progress)\n", | ||
" self.plyto.update_runtime(time() - self.start_time)\n", | ||
" self.plyto.send_data()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from time import time\n", | ||
"\n", | ||
"callback = PytorchBasicCallback(plyto_instance, 5, 12400) # 5 epochs of \n", | ||
" # 12400 mini-batches each\n", | ||
"\n", | ||
"for epoch in range(5): # loop over the dataset multiple times\n", | ||
" callback.update_step_number(epoch + 1) # update the current epoch\n", | ||
"\n", | ||
" running_loss = 0.0\n", | ||
" for i, data in enumerate(trainloader, 0):\n", | ||
" inputs, labels = data\n", | ||
"\n", | ||
" optimizer.zero_grad()\n", | ||
"\n", | ||
" outputs = net(inputs)\n", | ||
" loss = criterion(outputs, labels)\n", | ||
" loss.backward()\n", | ||
" optimizer.step()\n", | ||
"\n", | ||
" if i % 100 == 0 and i != 0: # print every 100 mini-batches\n", | ||
" callback.update_total_progress(100) # update total progress\n", | ||
" callback.update_data(i,) # update current progress and send data\n" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.6.5" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.