Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 21 additions & 13 deletions src/StyleTransfer/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
// This software is released under the MIT License.
// https://opensource.org/licenses/MIT

/* eslint max-len: "off" */
/* eslint no-trailing-spaces: "off" */
/*
Fast Style Transfer
This implementation is heavily based on github.com/reiinakano/fast-style-transfer-deeplearnjs by Reiichiro Nakano.
Expand Down Expand Up @@ -50,7 +48,6 @@ class StyleTransfer extends Video {
this.timesScalar = tf.scalar(150);
this.plusScalar = tf.scalar(255.0 / 2);
this.epsilonScalar = tf.scalar(1e-3);
this.video = null;
this.ready = callCallback(this.load(model), callback);
// this.then = this.ready.then;
}
Expand Down Expand Up @@ -90,8 +87,8 @@ class StyleTransfer extends Video {
const moments = tf.moments(input, [0, 1]);
const mu = moments.mean;
const sigmaSq = moments.variance;
const shift = this.variables[StyleTransfer.getVariableName(id)];
const scale = this.variables[StyleTransfer.getVariableName(id + 1)];
const shift = this.getVariable(id);
const scale = this.getVariable(id + 1);
const epsilon = this.epsilonScalar;
const normalized = tf.div(tf.sub(input.asType('float32'), mu), tf.sqrt(tf.add(sigmaSq, epsilon)));
const shifted = tf.add(tf.mul(scale, normalized), shift);
Expand All @@ -109,7 +106,7 @@ class StyleTransfer extends Video {
*/
convLayer(input, strides, relu, id) {
return tf.tidy(() => {
const y = tf.conv2d(input, this.variables[StyleTransfer.getVariableName(id)], [strides, strides], 'same');
const y = tf.conv2d(input, this.getVariable(id), [strides, strides], 'same');
const y2 = this.instanceNorm(y, id + 1);
return relu ? tf.relu(y2) : y2;
});
Expand Down Expand Up @@ -142,7 +139,7 @@ class StyleTransfer extends Video {
const newRows = height * strides;
const newCols = width * strides;
const newShape = [newRows, newCols, numFilters];
const y = tf.conv2dTranspose(input, this.variables[StyleTransfer.getVariableName(id)], newShape, [strides, strides], 'same');
const y = tf.conv2dTranspose(input, this.getVariable(id), newShape, [strides, strides], 'same');
const y2 = this.instanceNorm(y, id + 1);
const y3 = tf.relu(y2);
return y3;
Expand Down Expand Up @@ -180,6 +177,10 @@ class StyleTransfer extends Video {

/**
* @private
*
* Applies each layer of the model in sequence, where the output of one layer
* is used as the input of the next layer.
*
* @param {ImageData | HTMLImageElement | HTMLCanvasElement | HTMLVideoElement} input
* @return {Promise<HTMLImageElement>}
*/
Expand Down Expand Up @@ -221,12 +222,19 @@ class StyleTransfer extends Video {
this.epsilonScalar.dispose();
}

// Static Methods
static getVariableName(id) {
if (id === 0) {
return 'Variable';
}
return `Variable_${id}`;
/**
* @private
*
* Access a variable's tensor from its numeric index.
* Model contains variables with ids from 0 to 47.
* The returned tensor will be 4D if `id` is divisible by 3, or 1D otherwise.
*
* @param {number} id
* @returns {tf.Tensor}
*/
getVariable(id) {
const key = id === 0 ? 'Variable' : `Variable_${id}`;
return this.variables[key];
}
}

Expand Down