Have you ever woken up one day and thought, "My life would be so much better if I could define and train my machine learning models in React"? I haven't.
Define, train, and visualize the training of your ML models in your favorite front-end library (React), backed by your second favorite front-end library (TensorFlow.js).
- Define models in React/JSX
- Stream training data via ES6 Generators
- Turn on training progress visualization with a single flag
- Pause training
Check out the live demo or clone the demo app to get started more quickly.
Out of the box training metric visualization!
```shell
yarn add react @tensorflow/tfjs tfjsx
```
OR
```shell
npm install react @tensorflow/tfjs tfjsx
```
```jsx
import React from 'react';
import ReactDOM from 'react-dom';
import { Train, Model, Dense } from 'tfjsx';

// Define a generator of training data: each sample is an { x, y } pair
function* trainDataGenerator() {
  yield { x: 1, y: 1 };
  yield { x: 4, y: 4 };
  yield { x: 8, y: 8 };
}

function MyTrainedModel() {
  // Train the model with the training data generator defined above
  return (
    <Train
      trainData={trainDataGenerator}
      epochs={15}
      batchSize={3}
      samples={3}
      // tf.Model exposes summary() for printing the architecture
      onTrainEnd={model => model.summary()}
      train
      display
    >
      {/* Define the model architecture */}
      <Model optimizer='sgd' loss='meanSquaredError'>
        <Dense units={1} inputShape={[1]} />
      </Model>
    </Train>
  );
}

ReactDOM.render(<MyTrainedModel />, document.getElementById('app'));
```
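The three-sample generator in the example is enough for a smoke test. For a larger dataset, a generator function can compute samples on the fly instead of listing them by hand. The sketch below shows one way this might look in plain JavaScript. `makeLinearDataGenerator` and `SAMPLES` are hypothetical names introduced for illustration and are not part of tfjsx; the only assumption about the library is the `{ x, y }` sample shape described in this README.

```javascript
// Hypothetical helper (not part of tfjsx): returns a generator function that
// yields `count` samples of the line y = slope * x, using the { x, y } shape
// from the example above.
function makeLinearDataGenerator(count, slope = 1) {
  return function* linearData() {
    for (let i = 0; i < count; i++) {
      const x = i + 1;
      yield { x, y: slope * x };
    }
  };
}

// Keep the sample count in one place so it can be reused for the `samples` prop.
const SAMPLES = 100;
const trainDataGenerator = makeLinearDataGenerator(SAMPLES);
```

The resulting function could then be passed as `trainData={trainDataGenerator}` together with `samples={SAMPLES}`, so the declared sample count and the generator stay in sync.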
Property Name | Type | Description |
---|---|---|
trainData | function* () | The generator should yield an object with x and y properties corresponding to training data and label. |
validationData | function* () | Same as trainData, but should generate validation data. Used to output validation metrics during training. |
epochs | Number | Number of epochs to train the model for |
batchSize | Number | The number of samples to include in each training batch |
samples | Number | The number of samples the generator is expected to yield. |
onTrainEnd | function(tf.Model) | Called after the model is done training, the trained model is passed into the callback |
onBatchEnd | function(Object metrics, tf.Model) | Called after each batch is done training, an object with that batch's training metrics along with the current model is passed into the callback. |
train | Bool | Turn on and off training |
display | Bool | Enable or disable graphing of training status |
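To make the relationship between `trainData`, `samples`, and `batchSize` more concrete, here is a rough sketch of how samples from a generator could be grouped into batches. This is not tfjsx's actual implementation, which this README does not describe; `toBatches` is a hypothetical helper, and the rule that a final partial batch is still emitted is an assumption made for the illustration.

```javascript
// Illustrative only (hypothetical helper, not tfjsx internals): reads at most
// `samples` items from the generator and groups them into batches of `batchSize`.
function* toBatches(generatorFn, batchSize, samples) {
  let xs = [];
  let ys = [];
  let seen = 0;
  for (const { x, y } of generatorFn()) {
    if (seen >= samples) break; // never read past the declared sample count
    xs.push(x);
    ys.push(y);
    seen++;
    if (xs.length === batchSize) {
      yield { xs, ys };
      xs = [];
      ys = [];
    }
  }
  // Assumption: a leftover partial batch is emitted rather than dropped.
  if (xs.length > 0) yield { xs, ys };
}

// The generator from the example in this README.
function* trainDataGenerator() {
  yield { x: 1, y: 1 };
  yield { x: 4, y: 4 };
  yield { x: 8, y: 8 };
}

// With batchSize = 3 and samples = 3 there is exactly one batch per epoch:
// xs = [1, 4, 8], ys = [1, 4, 8]
const epochBatches = [...toBatches(trainDataGenerator, 3, 3)];
```

Under this reading, the number of batches per epoch is `Math.ceil(samples / batchSize)`, which is also how many times one would expect `onBatchEnd` to fire per epoch.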
`Model` accepts every config property that is valid for `model.compile` as a prop.
See config.
As with `Model`, each layer's props are passed through as properties of that layer's `config` object in TensorFlow.js.
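Because props are forwarded to TensorFlow.js, the JSX model from the usage example should correspond roughly to the plain TensorFlow.js calls sketched here. This assumes tfjsx forwards props unchanged, which is how this README describes it but is not verified here. `buildEquivalentModel` is a hypothetical name, and the `tf` namespace is taken as a parameter so the sketch does not depend on how the library is loaded.

```javascript
// Hypothetical helper: builds, with plain TensorFlow.js, the model that the
// JSX in the usage example describes. `tf` is the TensorFlow.js namespace,
// e.g. the result of require('@tensorflow/tfjs').
function buildEquivalentModel(tf) {
  const model = tf.sequential();

  // <Dense units={1} inputShape={[1]} /> maps to the dense layer config
  model.add(tf.layers.dense({ units: 1, inputShape: [1] }));

  // <Model optimizer='sgd' loss='meanSquaredError'> maps to the compile config
  model.compile({ optimizer: 'sgd', loss: 'meanSquaredError' });

  return model;
}
```

Any other property accepted by `model.compile` or by a layer's config (for example `metrics` on the model or `activation` on a dense layer) would be written as a prop in the same way.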
The following layers are currently available:
Adding new layer types is simple, and PRs are always welcome :)
- Model summarization (adding `display` flag to `Model`)
- Layer activation visualization (adding `display` flag to any layer)
- Model evaluation visualizations
- Allow pre-trained models to be used as a layer