Skip to content

Commit

Permalink
Added PyTorch example, minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
NoahStapp committed Aug 7, 2018
1 parent cb5f76a commit 2f45b38
Show file tree
Hide file tree
Showing 12 changed files with 769 additions and 399 deletions.
351 changes: 0 additions & 351 deletions IntroductionToPlyto.ipynb

This file was deleted.

327 changes: 327 additions & 0 deletions PlytoKeras.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,327 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Introduction to Plyto with PyTorch\n",
"#### Python Machine Learning Visualization Toolkit\n",
"This notebook will demonstrate how to use our example PyTorch loss callback function with Plyto to visualize model loss throughout the training process of a machine learning algorithm, as well as a tutorial on how to create your own callback function"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<img src='style/icons/machinelearning-blue.svg'> \n",
"toolbar item opens the Plyto model visualizer for this notebook!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### How it works"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A callback function that takes a or uses Plyto instance as a parameter is called each iteration through the training loop when using PyTorch. A Plyto instance requires an Altair spec to define plots. Below is an example of a simple altair spec to create a line graph of samples versus loss."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from plyto import PlytoAPI\n",
"\n",
"# an array of Altair specs with one plot of samples versus loss\n",
"spec = [\n",
" {\n",
" # specifies an altair spec\n",
" \"$schema\": \"https://vega.github.io/schema/vega-lite/v2.json\",\n",
" \"name\": \"lossGraph\",\n",
" \n",
" #size of the plot\n",
" \"config\": {\n",
" \"view\": {\n",
" \"height\": 300,\n",
" \"width\": 300\n",
" }\n",
" },\n",
" \n",
" # name of the dataset must be \"dataSet\"\n",
" \"data\": {\n",
" \"name\": \"dataSet\"\n",
" },\n",
" \n",
" # visual encodings of the plot\n",
" \"encoding\": {\n",
" \"x\": {\n",
" \"field\": \"samples\",\n",
" \"type\": \"quantitative\"\n",
" },\n",
" \"y\": {\n",
" \"field\": \"loss\",\n",
" \"type\": \"quantitative\"\n",
" }\n",
" },\n",
" \n",
" \"mark\": \"line\"\n",
" }\n",
"]\n",
"\n",
"# pass the spec into Plyto\n",
"plyto_instance = PlytoAPI(spec)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Running a model\n",
"To demonstrate how Plyto works, we will be looking at the CIFAR-10 tiny image data, which can be loaded from torchvision.datasets"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import torch\n",
"import torchvision\n",
"import torchvision.transforms as transforms\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"\n",
"\n",
"transform = transforms.Compose(\n",
" [transforms.ToTensor(),\n",
" transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
"\n",
"trainset = torchvision.datasets.CIFAR10(root='./data', train=True,\n",
" download=True, transform=transform)\n",
"trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,\n",
" shuffle=True, num_workers=2)\n",
"\n",
"testset = torchvision.datasets.CIFAR10(root='./data', train=False,\n",
" download=True, transform=transform)\n",
"testloader = torch.utils.data.DataLoader(testset, batch_size=4,\n",
" shuffle=False, num_workers=2)\n",
"\n",
"classes = ('plane', 'car', 'bird', 'cat',\n",
" 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n",
"\n",
"class Net(nn.Module):\n",
" def __init__(self):\n",
" super(Net, self).__init__()\n",
" self.conv1 = nn.Conv2d(3, 6, 5)\n",
" self.pool = nn.MaxPool2d(2, 2)\n",
" self.conv2 = nn.Conv2d(6, 16, 5)\n",
" self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
" self.fc2 = nn.Linear(120, 84)\n",
" self.fc3 = nn.Linear(84, 10)\n",
"\n",
" def forward(self, x):\n",
" x = self.pool(F.relu(self.conv1(x)))\n",
" x = self.pool(F.relu(self.conv2(x)))\n",
" x = x.view(-1, 16 * 5 * 5)\n",
" x = F.relu(self.fc1(x))\n",
" x = F.relu(self.fc2(x))\n",
" x = self.fc3(x)\n",
" return x\n",
"\n",
"\n",
"net = Net()\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"However you structure your network, simply call the callback function every N iterations through the training loop you want to update the data and open the Plyto model visualizer to see your statistics and plots update."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from time import time\n",
"from plyto import PytorchLossCallback\n",
"\n",
"callback = PytorchLossCallback(plyto_instance, 20, 12400) # 5 epochs of \n",
" # 12400 mini-batches each\n",
"\n",
"for epoch in range(20): # loop over the dataset multiple times\n",
" callback.update_step_number(epoch + 1) # update the current epoch\n",
"\n",
" running_loss = 0.0\n",
" for i, data in enumerate(trainloader, 0):\n",
" inputs, labels = data\n",
"\n",
" optimizer.zero_grad()\n",
"\n",
" outputs = net(inputs)\n",
" loss = criterion(outputs, labels)\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" running_loss += loss.item()\n",
" if i % 100 == 0 and i != 0: # print every 100 mini-batches\n",
" callback.update_total_progress(100) # update total progress\n",
" callback.update_data(i, running_loss / 100) # update current progress,\n",
" # loss, and send data\n",
" running_loss = 0.0"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Writing your own callback function"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A callback function for Plyto is a class that takes an instance of PlytoAPI as a parameter. \n",
"\n",
"Within this custom function, you can define functions to execute or update data at specific points in running the network.\n",
"\n",
"For the progress bars in the status bar to work correctly, your callback function must send epochs, sample_amount, total_progress, current_progress, and epoch_number using Plyto. Further, start_time is required for the panel to display the runtime once the model is complete. Below is a base to work off of, only containing these variables for basic functionality and passing no altair spec for plots."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"class PytorchBasicCallback:\n",
" \"\"\"\n",
" Create a callback that will track and display training progress\n",
"\n",
" :param steps: number of epochs/steps\n",
"\n",
" :param sample_amount: number of samples/steps per epoch\n",
"\n",
" :param start_time: start of training time, used to calculate runtime\n",
"\n",
" :param plyto: an instance of a PlytoAPI class\n",
" \"\"\"\n",
"\n",
" def __init__(self, plyto_instance, steps=0, sample_amount=0):\n",
" self.total_progress = 0\n",
" self.start_time = time()\n",
" self.plyto = plyto_instance\n",
" self.initalize_plyto(steps, sample_amount)\n",
"\n",
" def initalize_plyto(self, steps, sample_amount):\n",
" \"\"\"\n",
" Initalize the Plyto instance's total steps and step size\n",
" \n",
" :param steps: total number of steps\n",
"\n",
" :param sample_amount: number of samples/batches per step\n",
" \"\"\"\n",
" self.plyto.update_total_steps(steps)\n",
" self.plyto.update_size(sample_amount)\n",
"\n",
" def update_step_number(self, new_step):\n",
" \"\"\"\n",
" Update the current step/epoch\n",
"\n",
" :param new_step: the current step/epoch\n",
" \"\"\"\n",
" self.plyto.update_current_step(new_step)\n",
"\n",
" def update_total_progress(self, progress):\n",
" \"\"\"\n",
" Update the total training progress\n",
"\n",
" :param progress: the amount to increment the total progress by\n",
" \"\"\"\n",
" self.total_progress += progress\n",
"\n",
" def update_data(self, current_progress):\n",
" \"\"\"\n",
" Update progress, total progress, loss, and runtime before sending data to frontend\n",
"\n",
" :param current_progress: the progress of training the current step/epoch\n",
" \n",
" :param loss: the current batch's training loss\n",
" \"\"\"\n",
" self.plyto.update_current_progress(current_progress)\n",
" self.plyto.update_total_progress(self.total_progress)\n",
" self.plyto.update_runtime(time() - self.start_time)\n",
" self.plyto.send_data()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from time import time\n",
"\n",
"callback = PytorchBasicCallback(plyto_instance, 5, 12400) # 5 epochs of \n",
" # 12400 mini-batches each\n",
"\n",
"for epoch in range(5): # loop over the dataset multiple times\n",
" callback.update_step_number(epoch + 1) # update the current epoch\n",
"\n",
" running_loss = 0.0\n",
" for i, data in enumerate(trainloader, 0):\n",
" inputs, labels = data\n",
"\n",
" optimizer.zero_grad()\n",
"\n",
" outputs = net(inputs)\n",
" loss = criterion(outputs, labels)\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" if i % 100 == 0 and i != 0: # print every 100 mini-batches\n",
" callback.update_total_progress(100) # update total progress\n",
" callback.update_data(i,) # update current progress and send data\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading

0 comments on commit 2f45b38

Please sign in to comment.