Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions baseball-node/README.md
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
# TensorFlow.js Example: Training Baseball data in Node.js
# TensorFlow.js Example: Training a baseball model in Node.js

### This demo demonstrates using the Node.js bindings for sever-side model training and predictions.
This demo demonstrates how to use the [Node.js bindings](https://github.com/tensorflow/tfjs-node) for TensorFlow.js.

This package contains 3 components:
1. Models and training data for baseball
2. Node.js server for running pitch type model and reporting over socket.io
3. Client for listening to the server and displaying pitch type predictions
It has four parts:
1. Baseball sensor data
2. Two ML models that do classification given the sensor data:
- Model that predicts the type of pitch.
- Model that predicts if there was a strike.
2. Node.js server that trains a model and serves results over a web socket.
3. Web application that displays predictions and training stats.


## Running the Demo
First, prepare the environment:
First, prepare the environment and download the baseball data from MLB:
```sh
yarn && yarn download-data
```
Expand All @@ -26,12 +29,12 @@ In a new shell, start the server:
yarn start-server
```

To perform model only training to see how Node.js works with the two models, run the following:
If you are interested in testing out the training, without running a web server:
```sh
yarn train-pitch-type-model
yarn train-pitch-model
```
```sh
yarn train-strike-zone-model
yarn train-strike-model
```

## Pitch Models
Expand Down
54 changes: 27 additions & 27 deletions baseball-node/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,38 +9,38 @@
"download-data": "./scripts/download-data.sh",
"start-client": "webpack-dev-server --content-base src/client",
"start-server": "ts-node src/server/server.ts",
"train-pitch-type-model": "ts-node src/train/train-pitch-type-model.ts",
"train-strike-zone-model": "ts-node src/train/train-strike-zone-model.ts"
"train-pitch-model": "ts-node src/train/pitch-model.ts",
"train-strike-model": "ts-node src/train/strike-model.ts"
},
"license": "Apache-2.0",
"devDependencies": {
"@types/socket.io": "^1.4.33",
"@types/socket.io-client": "^1.4.32",
"@types/uuid": "^3.4.3",
"clang-format": "^1.2.3",
"containers.js": "^0.0.6",
"css-loader": "^0.28.11",
"file-loader": "^1.1.11",
"mkdirp": "^0.5.1",
"node-simple-timer": "^0.0.1",
"ts-loader": "^4.2.0",
"ts-node": "^6.0.2",
"tslint": "^5.9.1",
"typescript": "^2.8.3",
"uuid": "^3.2.1",
"vue-loader": "^14.2.2",
"vue-template-compiler": "^2.5.16",
"webpack": "^4.7.0",
"webpack-cli": "^2.1.2",
"webpack-dev-server": "^3.1.4",
"yalc": "^1.0.0-pre.22"
"@types/socket.io": "~1.4.33",
"@types/socket.io-client": "~1.4.32",
"@types/uuid": "~3.4.3",
"clang-format": "~1.2.3",
"css-loader": "~0.28.11",
"file-loader": "~1.1.11",
"mkdirp": "~0.5.1",
"ts-loader": "~4.2.0",
"ts-node": "~6.0.2",
"tslint": "~5.9.1",
"typescript": "~2.8.3",
"vue-loader": "~14.2.2",
"vue-template-compiler": "~2.5.16",
"webpack": "~4.7.0",
"webpack-cli": "~2.1.2",
"webpack-dev-server": "~3.1.4",
"yalc": "~1.0.0-pre.22"
},
"dependencies": {
"@tensorflow/tfjs": "^0.10.3",
"@tensorflow/tfjs": "~0.10.3",
"@tensorflow/tfjs-node": "link:.yalc/@tensorflow/tfjs-node",
"baseball-pitchfx-types": "^0.0.4",
"socket.io": "^2.1.0",
"vega-embed": "^3.9.0",
"vue": "^2.5.16"
"baseball-pitchfx-types": "~0.0.4",
"socket.io": "~2.1.0",
"vega-embed": "~3.9.0",
"vue": "~2.5.16",
"uuid": "~3.2.1",
"node-simple-timer": "~0.0.1",
"containers.js": "~0.0.6"
}
}
38 changes: 38 additions & 0 deletions baseball-node/src/abstract-pitch-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import * as tf from '@tensorflow/tfjs';
import {PitchData, PitchDataBatch, PitchTrainFields} from './pitch-data';
import {pitchFromType} from 'baseball-pitchfx-types';
import {AccuracyPerClass} from './types';

/** Info about progress during training. */
export interface TrainProgress {
Expand Down Expand Up @@ -56,6 +58,42 @@ export abstract class PitchModel {
}
}

/** Computes accuracy per class for the entire training set. */
async evaluate(): Promise<AccuracyPerClass> {
const batches = this.data.pitchBatches();
const correctPerClass: number[] = [];
const countPerClass: number[] = [];
const numClasses = batches[0].labels.shape[1];
for (let i = 0; i < numClasses; i++) {
correctPerClass[i] = 0;
countPerClass[i] = 0;
}

for (let i = 0; i < batches.length; i++) {
const batch = batches[i];
const predictionBatch = this.model.predict(batch.pitches) as tf.Tensor;
const labelIndicesBatch = batch.labels.argMax(1);
const isCorrectBatch =
await labelIndicesBatch.equal(predictionBatch.argMax(1)).data();
const labelBatch = await labelIndicesBatch.data();
for (let i = 0; i < isCorrectBatch.length; i++) {
const labelIndex = labelBatch[i];
const isCorrect = isCorrectBatch[i];
countPerClass[labelIndex]++;
if (isCorrect) {
correctPerClass[labelIndex]++;
}
}
}

// Return a dict that maps a class name to accuracy.
const result: AccuracyPerClass = {};
correctPerClass.forEach((correct, i) => {
result[pitchFromType(i)] = {training: correct / countPerClass[i]};
});
return result;
}

private async trainInternal(batch: PitchDataBatch,
callback: (progress: TrainProgress) => void, log = false) {
await this.model.fit(batch.pitches, batch.labels, {
Expand Down
37 changes: 35 additions & 2 deletions baseball-node/src/client/App.vue
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ limitations under the License.
<template>
<div class="container content">
<div id="accuracyCanvas"></div>
<div class="col-lg-4">
<div>
<div class="card" v-if="predictions.length === 0">
<div class="card-body text-center" style="color: #80868b">
Waiting for live pitch data...
</div>
</div>
<div id="table"></div>
</div>

<transition-group name="list-complete">
Expand All @@ -36,7 +37,39 @@ limitations under the License.
<script lang="ts" src="./app.ts"></script>

<style>
.html {
#table .row {
display: flex;
align-items: center;
margin: 5px 0;
}
.label {
text-align: center;
font-family: "Google Sans", sans-serif;
font-size: 24px;
color: #5f6368;
line-height: 24px;
font-weight: 500;
}
#table .label {
margin-right: 20px;
width: 300px;
text-align: right;
}
#table .score {
background-color: #999;
height: 30px;
text-align: right;
line-height: 30px;
color: white;
padding-right: 10px;
box-sizing: border-box;
}

#table .score-container {
border-right: 1px solid black;
}

html, body {
font-family: Roboto, sans-serif;
}
.flip-list-move {
Expand Down
9 changes: 0 additions & 9 deletions baseball-node/src/client/Pitch.vue
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,4 @@ limitations under the License.
display: inline-block;
width: 100%;
}

.label {
text-align: center;
font-family: "Google Sans", sans-serif;
font-size: 24px;
color: #5f6368;
line-height: 24px;
font-weight: 500;
}
</style>
52 changes: 51 additions & 1 deletion baseball-node/src/client/app.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import Vue from 'vue';
import embed from 'vega-embed';
import Pitch from './Pitch.vue';
import {TrainProgress} from '../abstract-pitch-model';
import {AccuracyPerClass} from '../types';

const maxPitches = 6 * 3; // 3 each row.

Expand Down Expand Up @@ -59,6 +60,10 @@ export default Vue.extend({
}
});

socket.on('accuracyPerClass', (accPerClass: AccuracyPerClass) => {
plotAccuracyPerClass(accPerClass);
});

socket.on('prediction_updates', (data: PitchPredictionUpdateMessage[]) => {
data.forEach((update) => {
const index = this.predictionMap.get(update.uuid);
Expand All @@ -80,12 +85,57 @@ export default Vue.extend({
}
});

const MAX_NUM_POINTS = 100;

function subsample(accuracy: Array<{batch: number, accuracy: number}>) {
const skip = Math.max(1, accuracy.length / MAX_NUM_POINTS);
const result: Array<{batch: number, accuracy: number}> = [];
for (let i = 0; i < accuracy.length; i += skip) {
result.push(accuracy[Math.round(i)]);
}
return result;
}

function plotAccuracyPerClass(accPerClass: AccuracyPerClass) {
const table = document.getElementById('table');
table.innerHTML = '';

const BAR_WIDTH_PX = 300;

for (const label in accPerClass) {
// Row.
const rowDiv = document.createElement('div');
rowDiv.className = 'row';
table.appendChild(rowDiv);

// Label.
const labelDiv = document.createElement('div');
labelDiv.innerText = label;
labelDiv.className = 'label';
rowDiv.appendChild(labelDiv);

// Score.
const scoreContainer = document.createElement('div');
scoreContainer.className = 'score-container';
scoreContainer.style.width = BAR_WIDTH_PX + 'px';

const scoreDiv = document.createElement('div');
scoreDiv.className = 'score';
const score = accPerClass[label].training;
scoreDiv.style.width = (score * BAR_WIDTH_PX) + 'px';
scoreDiv.innerHTML = (score * 100).toFixed(1) + '%';

scoreContainer.appendChild(scoreDiv);
rowDiv.appendChild(scoreContainer);
}
}

function plotProgress(progress: TrainProgress) {
accuracy.push({batch: accuracy.length + 1, accuracy: progress.accuracy});
embed(
'#accuracyCanvas', {
'$schema': 'https://vega.github.io/schema/vega-lite/v2.json',
'data': {'values': accuracy},
'data': {'values': subsample(accuracy)},
'width': 260,
'mark': {'type': 'line', 'legend': null, 'orient': 'vertical'},
'encoding': {
Expand Down
12 changes: 9 additions & 3 deletions baseball-node/src/server/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
import {bindTensorFlowBackend} from '@tensorflow/tfjs-node';
import {PitchTypeModel} from '../pitch-type-model';
import {Socket} from './socket';
import {sleep} from '../utils';

const TIMEOUT_BETWEEN_EPOCHS_MS = 100;

// Enable TFJS-Node backend
bindTensorFlowBackend();
Expand All @@ -28,11 +31,14 @@ const socket = new Socket(pitchModel);
async function run() {
socket.listen();
await pitchModel.train(1, progress => socket.sendProgress(progress));
socket.sendAccuracyPerClass(await pitchModel.evaluate());

setInterval(async () => {
while (true) {
await pitchModel.train(1, progress => socket.sendProgress(progress));
socket.sendAccuracyPerClass(await pitchModel.evaluate());
socket.broadcastUpdatedPredictions();
}, 3000);
await sleep(TIMEOUT_BETWEEN_EPOCHS_MS);
}
}

run();
run();
5 changes: 5 additions & 0 deletions baseball-node/src/server/socket.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import {loadPitchData} from '../pitch-data';
import {PitchTypeModel} from '../pitch-type-model';
import {getRandomInt} from '../utils';
import {TrainProgress} from '../abstract-pitch-model';
import {AccuracyPerClass} from '../types';

const PORT = 8001;
const PITCH_COUNT = 12;
Expand Down Expand Up @@ -78,6 +79,10 @@ export class Socket {
}
}

sendAccuracyPerClass(accPerClass: AccuracyPerClass) {
this.io.emit('accuracyPerClass', accPerClass);
}

sendProgress(progress: TrainProgress) {
this.io.emit('progress', progress);
}
Expand Down
20 changes: 20 additions & 0 deletions baseball-node/src/types.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

export type AccuracyPerClass = {
[label: string]: {training: number, validation?: number};
};
4 changes: 4 additions & 0 deletions baseball-node/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,7 @@ export function normalize(value: number, min: number, max: number): number {
export function getRandomInt(max: number) {
return Math.floor(Math.random() * Math.floor(max));
}

export function sleep(ms: number) {
return new Promise(resolve => setTimeout(resolve, ms));
}
Loading