Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,15 @@ then you can create the datasets for training, validation and testing manually,
and pass them as the `trainingDataset`, `validationDataset` and `testingDataset`
parameters.

You can also print the testing results by setting the `printResults` to `true`.
You can also print the testing results by setting the `printTestingResults` to
`true`.

An example can be found below:

```typescript
await trainer.trainAndTest({
data,
printResults: true
printTestingResults: true
});
```

Expand Down
11 changes: 2 additions & 9 deletions packages/tfjs-node-helpers-example/src/app/app.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,13 @@ import { OwnsTheCarFeatureExtractor } from './feature-extractors/owns-the-car';
import { join } from 'node:path';
import data from '../assets/data.json';

const EPOCHS = 1000;
const PATIENCE = 20;
const BATCH_SIZE = 32;

export async function startApplication() {
export async function startApplication(): Promise<void> {
await train();
await predict();
}

async function train(): Promise<void> {
const trainer = new BinaryClassificationTrainer({
batchSize: BATCH_SIZE,
epochs: EPOCHS,
patience: PATIENCE,
hiddenLayers: [
layers.dense({ units: 128, activation: 'mish' }),
layers.dense({ units: 128, activation: 'mish' })
Expand All @@ -35,7 +28,7 @@ async function train(): Promise<void> {

await trainer.trainAndTest({
data,
printResults: true
printTestingResults: true
});

await trainer.save(join(__dirname, './trained_model'));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,26 @@ import {
LayersModel,
model,
node,
onesLike,
Optimizer,
Scalar,
SymbolicTensor,
Tensor,
TensorContainer,
tidy,
where,
zerosLike
TensorContainer
} from '@tensorflow/tfjs-node';
import { green, red } from 'chalk';
import { Table } from 'console-table-printer';
import { FeatureExtractor } from '../feature-engineering/feature-extractor';
import { prepareDatasetsForBinaryClassification } from '../feature-engineering/prepare-datasets-for-binary-classification';
import { ConfusionMatrix } from '../testing/confusion-matrix';
import { Metrics } from '../testing/metrics';
import { binarize } from '../utils/binarize';

export type BinaryClassificationTrainerOptions = {
batchSize: number;
epochs: number;
patience: number;
inputFeatureExtractors: Array<FeatureExtractor<any, any>>;
outputFeatureExtractor: FeatureExtractor<any, any>;
batchSize?: number;
epochs?: number;
patience?: number;
inputFeatureExtractors?: Array<FeatureExtractor<any, any>>;
outputFeatureExtractor?: FeatureExtractor<any, any>;
model?: LayersModel;
hiddenLayers?: Array<layers.Layer>;
optimizer?: string | Optimizer;
Expand All @@ -40,60 +37,37 @@ export class BinaryClassificationTrainer {
protected epochs: number;
protected patience: number;
protected tensorBoardLogsDirectory?: string;
protected inputFeatureExtractors: Array<FeatureExtractor<any, any>>;
protected outputFeatureExtractor: FeatureExtractor<any, any>;
protected inputFeatureExtractors?: Array<FeatureExtractor<any, any>>;
protected outputFeatureExtractor?: FeatureExtractor<any, any>;
protected model!: LayersModel;

protected static DEFAULT_BATCH_SIZE: number = 32;
protected static DEFAULT_EPOCHS: number = 1000;
protected static DEFAULT_PATIENCE: number = 20;

constructor(options: BinaryClassificationTrainerOptions) {
this.batchSize = options.batchSize;
this.epochs = options.epochs;
this.patience = options.patience;
this.batchSize = options.batchSize ?? BinaryClassificationTrainer.DEFAULT_BATCH_SIZE;
this.epochs = options.epochs ?? BinaryClassificationTrainer.DEFAULT_EPOCHS;
this.patience = options.patience ?? BinaryClassificationTrainer.DEFAULT_PATIENCE;
this.tensorBoardLogsDirectory = options.tensorBoardLogsDirectory;
this.inputFeatureExtractors = options.inputFeatureExtractors;
this.outputFeatureExtractor = options.outputFeatureExtractor;

if (options.model !== undefined) {
this.model = options.model;
} else {
if (options.hiddenLayers !== undefined && options.inputFeatureExtractors !== undefined) {
const inputLayer = input({ shape: [options.inputFeatureExtractors.length] });
let symbolicTensor = inputLayer;

options.hiddenLayers.forEach((layer) => {
symbolicTensor = layer.apply(symbolicTensor) as SymbolicTensor;
});

const outputLayer = layers
.dense({ units: 1, activation: 'sigmoid' })
.apply(symbolicTensor) as SymbolicTensor;

this.model = model({
inputs: inputLayer,
outputs: outputLayer
});
} else {
throw new Error('hiddenLayers and inputFeaturesCount options are required when the model is not provided!');
}
}

this.model.compile({
optimizer: options.optimizer ?? 'adam',
loss: 'binaryCrossentropy'
});
this.initializeModel(options);
}

public async trainAndTest({
data,
trainingDataset,
validationDataset,
testingDataset,
printResults
printTestingResults
}: {
data?: Array<any>,
trainingDataset?: data.Dataset<TensorContainer>;
validationDataset?: data.Dataset<TensorContainer>;
testingDataset?: data.Dataset<TensorContainer>;
printResults?: boolean;
printTestingResults?: boolean;
}): Promise<{
loss: number;
confusionMatrix: ConfusionMatrix;
Expand All @@ -111,7 +85,15 @@ export class BinaryClassificationTrainer {
callbacks.push(node.tensorBoard(this.tensorBoardLogsDirectory));
}

if (trainingDataset === undefined || validationDataset === undefined || testingDataset === undefined) {
if (
trainingDataset === undefined ||
validationDataset === undefined ||
testingDataset === undefined
) {
if (this.inputFeatureExtractors === undefined || this.outputFeatureExtractor === undefined) {
throw new Error('trainingDataset, validationDataset and testingDataset are required when inputFeatureExtractors and outputFeatureExtractor are not provided!');
}

const datasets = await prepareDatasetsForBinaryClassification({
data: data as Array<any>,
inputFeatureExtractors: this.inputFeatureExtractors,
Expand All @@ -130,19 +112,50 @@ export class BinaryClassificationTrainer {
callbacks
});

return await this.test({ testingDataset, printResults });
return await this.test({ testingDataset, printTestingResults });
}

public async save(path: string): Promise<void> {
await this.model.save(`file://${path}`);
}

private initializeModel(options: BinaryClassificationTrainerOptions): void {
if (options.model !== undefined) {
this.model = options.model;
} else {
if (options.hiddenLayers !== undefined && options.inputFeatureExtractors !== undefined) {
const inputLayer = input({ shape: [options.inputFeatureExtractors.length] });
let symbolicTensor = inputLayer;

for (const layer of options.hiddenLayers) {
symbolicTensor = layer.apply(symbolicTensor) as SymbolicTensor;
}

const outputLayer = layers
.dense({ units: 1, activation: 'sigmoid' })
.apply(symbolicTensor) as SymbolicTensor;

this.model = model({
inputs: inputLayer,
outputs: outputLayer
});
} else {
throw new Error('hiddenLayers and inputFeatureExtractors options are required when the model is not provided!');
}
}

this.model.compile({
optimizer: options.optimizer ?? 'adam',
loss: 'binaryCrossentropy'
});
}

private async test({
testingDataset,
printResults
printTestingResults
}: {
testingDataset: data.Dataset<TensorContainer>;
printResults?: boolean;
printTestingResults?: boolean;
}): Promise<{
loss: number;
confusionMatrix: ConfusionMatrix;
Expand All @@ -151,23 +164,24 @@ export class BinaryClassificationTrainer {
const lossTensor = (await this.model.evaluateDataset(testingDataset as data.Dataset<any>, {})) as Scalar;
const [loss] = await lossTensor.data();

const testingData = (await testingDataset.toArray()) as Array<{
const [testingData] = (await testingDataset.toArray()) as Array<{
xs: Tensor;
ys: Tensor;
}>;
const testXs = testingData[0].xs;
const testYs = testingData[0].ys;

const testXs = testingData.xs;
const testYs = testingData.ys;

const predictions = this.model.predict(testXs) as Tensor;
const binarizedPredictions = this.binarize(predictions);
const binarizedPredictions = binarize(predictions);

const trueValues = (await testYs.data()) as Float32Array;
const predictedValues = (await binarizedPredictions.data()) as Float32Array;
const trueValues = await testYs.data<'float32'>();
const predictedValues = await binarizedPredictions.data<'float32'>();

const confusionMatrix = this.calculateConfusionMatrix(trueValues, predictedValues);
const metrics = this.calculateMetrics(confusionMatrix);

if (printResults) {
if (printTestingResults) {
this.printTestResults(loss, confusionMatrix, metrics);
}

Expand Down Expand Up @@ -301,12 +315,4 @@ export class BinaryClassificationTrainer {

metricsTable.printTable();
}

private binarize(tensor: Tensor, threshold = 0.5): Tensor {
return tidy(() => {
const condition = tensor.greater(threshold);

return where(condition, onesLike(tensor), zerosLike(tensor));
});
}
}
9 changes: 9 additions & 0 deletions packages/tfjs-node-helpers/src/utils/binarize.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import { onesLike, Tensor, tidy, where, zerosLike } from '@tensorflow/tfjs-node';

export const binarize = (tensor: Tensor, threshold = 0.5): Tensor => tidy(
() => where(
tensor.greater(threshold),
onesLike(tensor),
zerosLike(tensor)
)
);