Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
c80ccfc
Support saving optimizers
caisq Apr 28, 2019
e68456b
save
caisq Apr 28, 2019
edd0296
Worked on Adadelta
caisq Apr 28, 2019
828c017
Add unit test for RMSPropOptimzier getWeights() and setWeights()
caisq Apr 28, 2019
d00a018
Adjust weight order in RMSPropOptimizer
caisq Apr 28, 2019
d8c7e6f
Add unit tests for AdadeltaOptimizer
caisq Apr 28, 2019
3f84ceb
Add WeightType
caisq Apr 28, 2019
f1c8521
Adjust RMSProp class name to align with TF v2
caisq Apr 29, 2019
02a02d9
Update SGD optimizer
caisq Apr 29, 2019
5699904
save
caisq Apr 29, 2019
9def0e7
Fix unit tests
caisq Apr 29, 2019
0b284d2
Fix Adam
caisq Apr 30, 2019
7a72bfa
Add unit tests for Adam
caisq Apr 30, 2019
602e379
Remove NamedVariable interface
caisq Apr 30, 2019
716c091
Add iterations to saved weights
caisq Apr 30, 2019
0413be8
Worked on Adagrad
caisq Apr 30, 2019
76b9ad7
Revise adagrad class name
caisq Apr 30, 2019
3e1ae22
Use forEach
caisq Apr 30, 2019
7a788a4
Make Optimizer.applyGradients handle null gradient values
caisq Apr 30, 2019
b875594
Working Adamax
caisq Apr 30, 2019
13c0284
save
caisq Apr 30, 2019
ac20c55
Worked on Momentum
caisq May 1, 2019
c8fabbb
Add VariableWithOriginalName
caisq May 1, 2019
a673850
Merge branch 'master' into save-optimizers
caisq May 1, 2019
ef5572d
Fix tests
caisq May 1, 2019
b52d65c
Clean up
caisq May 1, 2019
8bc46ab
Merge branch 'master' into save-optimizers
caisq May 2, 2019
a9543e1
Address some comments
caisq May 2, 2019
8dd4593
Save
caisq May 2, 2019
eef25af
Fix unit tests
caisq May 2, 2019
d9c3972
Merge branch 'master' into save-optimizers
caisq May 2, 2019
5c5746f
WIP
caisq May 2, 2019
f64dd6d
Merge branch 'save-optimizers' of github.com:caisq/deeplearnjs into s…
caisq May 3, 2019
36bd8e0
Make setWeights() async
caisq May 3, 2019
727d51a
Make getWeights() async
caisq May 3, 2019
1e2073f
async saveIterations
caisq May 3, 2019
2a7c14e
extractIterations()
caisq May 3, 2019
9254be7
Fix AdamOptimizer accumulated beta1 and beta2 state
caisq May 3, 2019
0265364
Merge branch 'master' into save-optimizers
caisq May 3, 2019
317b362
Merge branch 'master' into save-optimizers
caisq May 3, 2019
1ec1e14
Merge branch 'master' into save-optimizers
caisq May 3, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions src/gradients.ts
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,12 @@ function valueAndGrads<O extends Tensor>(f: (...args: Tensor[]) => O): (
* @param f The function to execute. f() should return a scalar.
* @param varList The list of variables to compute the gradients with respect
* to. Defaults to all trainable variables.
* @returns An object with the following keys and values:
* - `value`: The value of the function `f`.
* - `grads`: A map from the names of the variables to the gradients.
* If the `varList` argument is provided explicitly and contains a subset of
* non-trainable variables, this map in the return value will contain keys
* that map the names of the non-trainable variables to `null`.
*/
/** @doc {heading: 'Training', subheading: 'Gradients'} */
function variableGrads(f: () => Scalar, varList?: Variable[]):
Expand All @@ -270,21 +276,26 @@ function variableGrads(f: () => Scalar, varList?: Variable[]):
() =>
'The varList passed in variableGrads(f, varList) must be an array ' +
'of variables');
if (varList == null) {

const specifiedVarList = varList != null;
if (!specifiedVarList) {
// Get all of the trainable variables.
varList = [];
for (const varName in ENGINE.registeredVariables) {
varList.push(ENGINE.registeredVariables[varName]);
}
}

const specifiedNonTrainable: Variable[] =
specifiedVarList ? varList.filter(variable => !variable.trainable) : null;

// Prune non-trainable variables.
const originalVarCount = varList.length;
varList = varList.filter(variable => variable.trainable);
util.assert(
varList.length > 0,
() =>
`variableGrads() expects at least one of the input variables to be ` +
`trainable, but none of the ${originalVarCount} variables is ` +
() => `variableGrads() expects at least one of the input variables to ` +
`be trainable, but none of the ${originalVarCount} variables is ` +
`trainable.`);

const allowNoGradients = true;
Expand All @@ -306,6 +317,11 @@ function variableGrads(f: () => Scalar, varList?: Variable[]):
namedGrads[v.name] = grads[i];
}
});
if (specifiedNonTrainable != null) {
// If varList is explicitly provided and contains non-trainable values,
// add them to the returned gradients with `null` values.
specifiedNonTrainable.forEach(v => namedGrads[v.name] = null);
}
return {value, grads: namedGrads};
}

Expand Down
3 changes: 2 additions & 1 deletion src/io/io.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import {browserHTTPRequest, http, isHTTPScheme} from './http';
import {concatenateArrayBuffers, decodeWeights, encodeWeights, getModelArtifactsInfoForJSON} from './io_utils';
import {fromMemory, withSaveHandler} from './passthrough';
import {getLoadHandlers, getSaveHandlers, registerLoadRouter, registerSaveRouter} from './router_registry';
import {IOHandler, LoadHandler, LoadOptions, ModelArtifacts, ModelJSON, ModelStoreManager, OnProgressCallback, SaveConfig, SaveHandler, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types';
import {IOHandler, LoadHandler, LoadOptions, ModelArtifacts, ModelJSON, ModelStoreManager, OnProgressCallback, SaveConfig, SaveHandler, SaveResult, WeightsManifestConfig, WeightsManifestEntry, WeightGroup} from './types';
import {loadWeights, weightsLoaderFactory} from './weights_loader';

export {copyModel, listModels, moveModel, removeModel} from './model_management';
Expand Down Expand Up @@ -57,5 +57,6 @@ export {
weightsLoaderFactory,
WeightsManifestConfig,
WeightsManifestEntry,
WeightGroup,
withSaveHandler
};
25 changes: 19 additions & 6 deletions src/io/io_utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@

import {tensor} from '../ops/tensor_ops';
import {Tensor} from '../tensor';
import {NamedTensorMap} from '../tensor_types';
import {NamedTensor, NamedTensorMap} from '../tensor_types';
import {TypedArray} from '../types';
import {sizeFromShape} from '../util';
import {DTYPE_VALUE_SIZE_MAP, ModelArtifacts, ModelArtifactsInfo, WeightsManifestEntry} from './types';

import {DTYPE_VALUE_SIZE_MAP, ModelArtifacts, ModelArtifactsInfo, WeightGroup, WeightsManifestEntry} from './types';

/**
* Encode a map from names to weight values as an ArrayBuffer, along with an
Expand All @@ -31,27 +32,39 @@ import {DTYPE_VALUE_SIZE_MAP, ModelArtifacts, ModelArtifactsInfo, WeightsManifes
* This function is the reverse of `decodeWeights`.
*
* @param tensors A map ("dict") from names to tensors.
* @param group Group to which the weights belong (optional).
* @returns A `Promise` of
* - A flat `ArrayBuffer` with all the binary values of the `Tensor`s
* concatenated.
* - An `Array` of `WeightManifestEntry`s, carrying information including
* tensor names, `dtype`s and shapes.
* @throws Error: on unsupported tensor `dtype`.
*/
export async function encodeWeights(tensors: NamedTensorMap):
export async function encodeWeights(
tensors: NamedTensorMap|NamedTensor[], group?: WeightGroup):
Promise<{data: ArrayBuffer, specs: WeightsManifestEntry[]}> {
// TODO(adarob, cais): Support quantization.
const specs: WeightsManifestEntry[] = [];
const dataPromises: Array<Promise<TypedArray>> = [];
for (const name in tensors) {
const t = tensors[name];

const names: string[] = Array.isArray(tensors) ?
tensors.map(tensor => tensor.name) :
Object.keys(tensors);

for (let i = 0; i < names.length; ++i) {
const name = names[i];
const t = Array.isArray(tensors) ? tensors[i].tensor : tensors[name];
if (t.dtype !== 'float32' && t.dtype !== 'int32' && t.dtype !== 'bool') {
throw new Error(`Unsupported dtype in weight '${name}': ${t.dtype}`);
}
specs.push({name, shape: t.shape, dtype: t.dtype});
const spec: WeightsManifestEntry = {name, shape: t.shape, dtype: t.dtype};
if (group != null) {
spec.group = group;
}
specs.push(spec);
dataPromises.push(t.data());
}

const tensorValues = await Promise.all(dataPromises);
return {data: concatenateTypedArrays(tensorValues), specs};
}
Expand Down
63 changes: 61 additions & 2 deletions src/io/io_utils_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import * as tf from '../index';
import {describeWithFlags} from '../jasmine_util';
import {scalar, tensor1d, tensor2d} from '../ops/ops';
import {NamedTensorMap} from '../tensor_types';
import {NamedTensor, NamedTensorMap} from '../tensor_types';
import {expectArraysEqual} from '../test_util';
import {expectArraysClose} from '../test_util';

Expand Down Expand Up @@ -178,7 +178,7 @@ describe('concatenateTypedArrays', () => {
});

describe('encodeWeights', () => {
it('Float32 tensors', async done => {
it('Float32 tensors as NamedTensorMap', async done => {
const tensors: NamedTensorMap = {
x1: tensor2d([[10, 20], [30, 40]]),
x2: scalar(42),
Expand Down Expand Up @@ -220,6 +220,65 @@ describe('encodeWeights', () => {
});
});

it('Float32 tensors as NamedTensor array', async done => {
const tensors: NamedTensor[] = [
{name: 'x1234', tensor: tensor2d([[10, 20], [30, 40]])}, {
name: 'a42',
tensor: scalar(42),
},
{name: 'b41', tensor: tensor1d([-1.3, -3.7, 1.3, 3.7])}
];
tf.io.encodeWeights(tensors)
.then(dataAndSpecs => {
const data = dataAndSpecs.data;
const specs = dataAndSpecs.specs;
expect(data.byteLength).toEqual(4 * (4 + 1 + 4));
expect(new Float32Array(data, 0, 4)).toEqual(new Float32Array([
10, 20, 30, 40
]));
expect(new Float32Array(data, 16, 1)).toEqual(new Float32Array([42]));
expect(new Float32Array(data, 20, 4)).toEqual(new Float32Array([
-1.3, -3.7, 1.3, 3.7
]));
expect(specs).toEqual([
{
name: 'x1234',
dtype: 'float32',
shape: [2, 2],
},
{
name: 'a42',
dtype: 'float32',
shape: [],
},
{
name: 'b41',
dtype: 'float32',
shape: [4],
}
]);
done();
})
.catch(err => {
console.error(err.stack);
});
});

it('Empty NamedTensor array', async done => {
const tensors: NamedTensor[] = [];
tf.io.encodeWeights(tensors)
.then(dataAndSpecs => {
const data = dataAndSpecs.data;
const specs = dataAndSpecs.specs;
expect(data.byteLength).toEqual(0);
expect(specs).toEqual([]);
done();
})
.catch(err => {
console.error(err.stack);
});
});

it('Int32 tensors', async done => {
const tensors: NamedTensorMap = {
x1: tensor2d([[10, 20], [30, 40]], [2, 2], 'int32'),
Expand Down
26 changes: 25 additions & 1 deletion src/io/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ export declare interface WeightsManifestGroupConfig {
weights: WeightsManifestEntry[];
}

/**
* Group to which the weight belongs.
*
* - 'optimizer': Weight from a stateful optimizer.
*/
export type WeightGroup = 'model'|'optimizer';

/**
* An entry in the weight manifest.
*
Expand All @@ -80,6 +87,16 @@ export declare interface WeightsManifestEntry {
*/
dtype: 'float32'|'int32'|'bool';

/**
* Type of the weight.
*
* Optional.
*
* The value 'optimizer' indicates the weight belongs to an optimizer
* (i.e., used only during model training and not during inference).
*/
group?: WeightGroup;

/**
* Information for dequantization of the weight.
*/
Expand All @@ -97,9 +114,16 @@ export declare interface WeightsManifestEntry {
export interface SaveConfig {
/**
* Whether to save only the trainable weights of the model, ignoring the
* untrainable ones.
* non-trainable ones.
*/
trainableOnly?: boolean;

/**
* Whether the optimizer will be saved (if exists).
*
* Default: `false`.
*/
includeOptimizer?: boolean;
}

/**
Expand Down
96 changes: 66 additions & 30 deletions src/optimizers/adadelta_optimizer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@
*/

import {ENGINE} from '../engine';
import {tidy} from '../globals';
import {dispose, tidy} from '../globals';
import {zerosLike} from '../ops/ops';
import {ConfigDict, registerClass, Serializable, SerializableConstructor} from '../serialization';
import {NamedVariableMap} from '../tensor_types';
import {Optimizer} from './optimizer';
import {NamedTensor, NamedVariableMap} from '../tensor_types';
import {Optimizer, OptimizerVariable} from './optimizer';

/** @doclink Optimizer */
export class AdadeltaOptimizer extends Optimizer {
/** @nocollapse */
static className = 'AdadeltaOptimizer';
private accumulatedGrads: NamedVariableMap = {};
private accumulatedUpdates: NamedVariableMap = {};
private accumulatedGrads: OptimizerVariable[] = [];
private accumulatedUpdates: OptimizerVariable[] = [];

constructor(
protected learningRate: number, protected rho: number,
Expand All @@ -39,27 +39,36 @@ export class AdadeltaOptimizer extends Optimizer {
}
}

applyGradients(variableGradients: NamedVariableMap) {
for (const variableName in variableGradients) {
const value = ENGINE.registeredVariables[variableName];
if (this.accumulatedGrads[variableName] == null) {
const trainable = false;
tidy(() => {
this.accumulatedGrads[variableName] =
zerosLike(value).variable(trainable);
});
applyGradients(variableGradients: NamedVariableMap|NamedTensor[]) {
const variableNames = Array.isArray(variableGradients) ?
variableGradients.map(item => item.name) :
Object.keys(variableGradients);

variableNames.forEach((name, i) => {
const value = ENGINE.registeredVariables[name];
const trainable = false;
if (this.accumulatedGrads[i] == null) {
this.accumulatedGrads[i] = {
originalName: `${name}/accum_grad`,
variable: tidy(() => zerosLike(value).variable(trainable))
};
}
if (this.accumulatedUpdates[i] == null) {
this.accumulatedUpdates[i] = {
originalName: `${name}/accum_var`,
variable: tidy(() => zerosLike(value).variable(trainable))
};
}
if (this.accumulatedUpdates[variableName] == null) {
const trainable = false;
tidy(() => {
this.accumulatedUpdates[variableName] =
zerosLike(value).variable(trainable);
});

const gradient = Array.isArray(variableGradients) ?
variableGradients[i].tensor :
variableGradients[name];
if (gradient == null) {
return;
}

const gradient = variableGradients[variableName];
const accumulatedGrad = this.accumulatedGrads[variableName];
const accumulatedUpdate = this.accumulatedUpdates[variableName];
const accumulatedGrad = this.accumulatedGrads[i].variable;
const accumulatedUpdate = this.accumulatedUpdates[i].variable;

tidy(() => {
const newAccumulatedGrad = accumulatedGrad.mul(this.rho).add(
Expand All @@ -73,23 +82,50 @@ export class AdadeltaOptimizer extends Optimizer {
const newAccumulatedUpdate = accumulatedUpdate.mul(this.rho).add(
updates.square().mul(1 - this.rho));

this.accumulatedGrads[variableName].assign(newAccumulatedGrad);
this.accumulatedUpdates[variableName].assign(newAccumulatedUpdate);
accumulatedGrad.assign(newAccumulatedGrad);
accumulatedUpdate.assign(newAccumulatedUpdate);

const newValue = updates.mul(-this.learningRate).add(value);
value.assign(newValue);
});
}
});
this.incrementIterations();
}

dispose(): void {
super.dispose();
if (this.accumulatedUpdates != null) {
Object.keys(this.accumulatedUpdates)
.forEach(name => this.accumulatedUpdates[name].dispose());
Object.keys(this.accumulatedGrads)
.forEach(name => this.accumulatedGrads[name].dispose());
dispose(this.accumulatedGrads.map(v => v.variable));
dispose(this.accumulatedUpdates.map(v => v.variable));
}
}

getWeights(): NamedTensor[] {
// Order matters for Python compatibility.
const variables: OptimizerVariable[] =
[...this.accumulatedGrads, ...this.accumulatedUpdates];
return super.getWeights().concat(
variables.map(v => ({name: v.originalName, tensor: v.variable})));
}

setWeights(weightValues: NamedTensor[]): void {
weightValues = super.setIterations(weightValues);
const variableCount = weightValues.length / 2;
const trainable = false;
this.accumulatedGrads =
weightValues.slice(0, variableCount).map(v => ({
originalName: v.name,
variable: v.tensor.variable(
trainable)
}));
this.accumulatedUpdates =
weightValues.slice(variableCount, variableCount * 2)
.map(v => ({
originalName: v.name,
variable: v.tensor.variable(trainable)
}));
}

getConfig(): ConfigDict {
return {
'learningRate': this.learningRate,
Expand Down
Loading