Skip to content

Commit f8e4b8f

Browse files
committed
move layer creation and other task-dependent logic into interchangeable task objects
1 parent 834e22f commit f8e4b8f

File tree

5 files changed

+316
-233
lines changed

5 files changed

+316
-233
lines changed

src/NeuralNetwork/NeuralNetwork.js

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ class NeuralNetwork {
2020
this.createModel = this.createModel.bind(this);
2121
this.addLayer = this.addLayer.bind(this);
2222
this.compile = this.compile.bind(this);
23-
this.setOptimizerFunction = this.setOptimizerFunction.bind(this);
2423
this.train = this.train.bind(this);
2524
this.predict = this.predict.bind(this);
2625
this.classify = this.classify.bind(this);
@@ -57,11 +56,10 @@ class NeuralNetwork {
5756
/**
5857
* add layer to the model
5958
* if the model has 2 or more layers switch the isLayered flag
60-
* @param {*} _layerOptions
59+
* @param {tf.layers.Layer} layer
6160
*/
62-
addLayer(_layerOptions) {
63-
const LAYER_OPTIONS = _layerOptions || {};
64-
this.model.add(LAYER_OPTIONS);
61+
addLayer(layer) {
62+
this.model.add(layer);
6563

6664
// check if it has at least an input and output layer
6765
if (this.model.layers.length >= 2) {
@@ -72,23 +70,13 @@ class NeuralNetwork {
7270
/**
7371
* Compile the model
7472
* if the model is compiled, set the isCompiled flag to true
75-
* @param {*} _modelOptions
73+
* @param {tf.ModelCompileArgs} _modelOptions
7674
*/
7775
compile(_modelOptions) {
7876
this.model.compile(_modelOptions);
7977
this.isCompiled = true;
8078
}
8179

82-
/**
83-
* Set the optimizer function given the learning rate
84-
* as a parameter
85-
* @param {*} learningRate
86-
* @param {*} optimizer
87-
*/
88-
setOptimizerFunction(learningRate, optimizer) {
89-
return optimizer.call(this, learningRate);
90-
}
91-
9280
/**
9381
* Train the model
9482
* @param {*} _options

src/NeuralNetwork/getTask.ts

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
import * as tf from '@tensorflow/tfjs';
2+
import { LayerJson, NeuralNetworkOptions } from "./types";
3+
4+
/**
5+
* Separate all task-dependent logic into separate task objects to minimize if/else behavior
6+
* in the main Neural Network class and make it easier to potentially add more tasks in the future.
7+
* May want these to be classes which get the nn instance in the constructor.
8+
*/
9+
10+
export type TaskName = 'classification' | 'regression' | 'imageClassification';
11+
12+
export interface NNTask {
13+
name: TaskName;
14+
15+
// Can optionally override the standard defaults with custom defaults
16+
getDefaultOptions?(): Partial<NeuralNetworkOptions>;
17+
18+
// Note: learningRate is always the first arg of the optimizer, but some optimizers support other optional args as well
19+
getCompileOptions(learningRate: number): tf.ModelCompileArgs;
20+
21+
createLayers(inputShape: tf.Shape, hiddenUnits: number, outputUnits: number): LayerJson[];
22+
23+
getSampleData(inputs: number | string[] | number[], outputs: number | string[]): { xs: number[], ys: (string | number)[] }[]
24+
25+
// TODO: parseInputs and parseOutputs
26+
}
27+
28+
// TODO: move elsewhere
29+
function isStringArray(value: any): value is string[] {
30+
return Array.isArray(value) && value.some(v => typeof v === 'string');
31+
}
32+
33+
// Handling of input sample is the same for all tasks.
34+
function getSampleInput(inputs: number | string[] | number[]): number[] {
35+
if (isStringArray(inputs)) {
36+
throw new Error(`'inputs' cannot be an array of property names when using option 'noTraining'. You must specify the number of inputs.`);
37+
}
38+
const inputSize = Array.isArray(inputs) ? inputs.reduce((a, b) => a * b) : inputs;
39+
return new Array(inputSize).fill(0);
40+
}
41+
42+
const classificationTask: NNTask = {
43+
name: 'classification',
44+
getCompileOptions(learningRate) {
45+
return {
46+
loss: 'categoricalCrossentropy',
47+
optimizer: tf.train.sgd(learningRate),
48+
metrics: ['accuracy'],
49+
}
50+
},
51+
createLayers(inputShape, hiddenUnits, outputUnits) {
52+
return [
53+
{
54+
type: 'dense',
55+
units: hiddenUnits,
56+
activation: 'relu',
57+
inputShape
58+
},
59+
{
60+
type: 'dense',
61+
activation: 'softmax',
62+
units: outputUnits,
63+
},
64+
];
65+
},
66+
getSampleData(inputs, outputs) {
67+
if (!isStringArray(outputs)) {
68+
throw new Error(`Invalid outputs ${outputs}. Outputs must be an array of label names when using option 'noTraining' with task 'classification'.`);
69+
}
70+
const xs = getSampleInput(inputs);
71+
return outputs.map(label => ({ xs, ys: [label] }));
72+
}
73+
}
74+
75+
const imageClassificationTask: NNTask = {
76+
name: 'imageClassification',
77+
getDefaultOptions() {
78+
return {
79+
learningRate: 0.02
80+
}
81+
},
82+
getCompileOptions: classificationTask.getCompileOptions,
83+
createLayers(inputShape, hiddenUnits, outputUnits) {
84+
return [
85+
{
86+
type: 'conv2d',
87+
filters: 8,
88+
kernelSize: 5,
89+
strides: 1,
90+
activation: 'relu',
91+
kernelInitializer: 'varianceScaling',
92+
inputShape,
93+
},
94+
{
95+
type: 'maxPooling2d',
96+
poolSize: [2, 2],
97+
strides: [2, 2],
98+
},
99+
{
100+
type: 'conv2d',
101+
filters: 16,
102+
kernelSize: 5,
103+
strides: 1,
104+
activation: 'relu',
105+
kernelInitializer: 'varianceScaling',
106+
},
107+
{
108+
type: 'maxPooling2d',
109+
poolSize: [2, 2],
110+
strides: [2, 2],
111+
},
112+
{
113+
type: 'flatten',
114+
},
115+
{
116+
type: 'dense',
117+
kernelInitializer: 'varianceScaling',
118+
activation: 'softmax',
119+
units: outputUnits,
120+
},
121+
];
122+
},
123+
getSampleData: classificationTask.getSampleData
124+
}
125+
126+
const regressionTask: NNTask = {
127+
name: 'regression',
128+
getCompileOptions(learningRate) {
129+
return {
130+
loss: 'meanSquaredError',
131+
optimizer: tf.train.adam(learningRate),
132+
metrics: ['accuracy'],
133+
};
134+
},
135+
createLayers(inputShape, hiddenUnits, outputUnits) {
136+
return [
137+
{
138+
type: 'dense',
139+
units: hiddenUnits,
140+
activation: 'relu',
141+
inputShape
142+
},
143+
{
144+
type: 'dense',
145+
activation: 'sigmoid',
146+
units: outputUnits,
147+
},
148+
];
149+
},
150+
getSampleData(inputs, outputs) {
151+
if (typeof outputs !== 'number') {
152+
throw new Error(`Invalid outputs ${outputs}. Outputs must be a number when using option 'noTraining' with task 'regression'.`);
153+
}
154+
return [{
155+
xs: getSampleInput(inputs),
156+
ys: new Array(outputs).fill(0)
157+
}]
158+
}
159+
}
160+
161+
/**
162+
* Mapping of supported task configurations and their task names.
163+
* Use lowercase keys to make the lookup case-insensitive.
164+
*/
165+
const TASKS: Record<Lowercase<TaskName>, NNTask> = {
166+
regression: regressionTask,
167+
classification: classificationTask,
168+
imageclassification: imageClassificationTask,
169+
}
170+
171+
/**
172+
* Get the correct task object based on the task name.
173+
*/
174+
export default function getTask(name: TaskName | string): NNTask {
175+
const task = TASKS[name.toLowerCase()];
176+
if (!task) {
177+
throw new Error(`Unknown task name '${name}'. Task must be one of ${Object.keys(TASKS).join(', ')}`);
178+
}
179+
return task;
180+
}

0 commit comments

Comments
 (0)