-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
training.ts
2208 lines (2064 loc) · 83.2 KB
/
training.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/**
* @license
* Copyright 2018 Google LLC
*
* Use of this source code is governed by an MIT-style
* license that can be found in the LICENSE file or at
* https://opensource.org/licenses/MIT.
* =============================================================================
*/
/* Original Source: engine/training.py */
import * as tfc from '@tensorflow/tfjs-core';
import {io, ModelPredictConfig as ModelPredictArgs, NamedTensorMap, Optimizer, Scalar, scalar, serialization, Tensor, Tensor1D, tensor1d, util} from '@tensorflow/tfjs-core';
import * as K from '../backend/tfjs_backend';
import {BaseCallback, configureCallbacks, History, ModelLoggingVerbosity, standardizeCallbacks} from '../base_callbacks';
import {nameScope} from '../common';
import {NotImplementedError, RuntimeError, ValueError} from '../errors';
import {Shape} from '../keras_format/common';
import {LossIdentifier} from '../keras_format/loss_config';
import {OptimizerSerialization} from '../keras_format/optimizer_config';
import {MetricsIdentifier, TrainingConfig} from '../keras_format/training_config';
import {deserialize} from '../layers/serialization';
import { disposeTensorsInLogs, UnresolvedLogs } from '../logs';
import * as losses from '../losses';
import * as Metrics from '../metrics';
import * as optimizers from '../optimizers';
import {LossOrMetricFn, NamedTensor} from '../types';
import {checkUserDefinedMetadata} from '../user_defined_metadata';
import {count, pyListRepeat, singletonOrArray, toCamelCase, toSnakeCase, unique} from '../utils/generic_utils';
import {printSummary} from '../utils/layer_utils';
import {range} from '../utils/math_utils';
import {convertPythonicToTs} from '../utils/serialization_utils';
import {LayerVariable} from '../variables';
import {version} from '../version';
import {Container, ContainerArgs} from './container';
import {Dataset} from './dataset_stub';
import {execute, FeedDict} from './executor';
import {DisposeResult, SymbolicTensor} from './topology';
import {evaluateDataset, fitDataset, ModelEvaluateDatasetArgs, ModelFitDatasetArgs} from './training_dataset';
import {checkBatchSize, disposeNewTensors, ensureTensorsRank2OrHigher, makeBatches, ModelFitArgs, sliceArrays, sliceArraysByIndices} from './training_tensors';
import {ClassWeight, ClassWeightMap, computeWeightedLoss, standardizeClassWeights, standardizeWeights} from './training_utils';
/**
 * Polymorphic-input helper, case 1 of 3: is `x` a single `tf.Tensor`?
 *
 * @param x User-supplied data in any of the accepted container forms.
 * @returns `true` iff `x` is an instance of `Tensor`.
 */
export function isDataTensor(
    x: Tensor|Tensor[]|{[inputName: string]: Tensor}|
    {[inputName: string]: Tensor[]}): boolean {
  const isSingleton = x instanceof Tensor;
  return isSingleton;
}
/**
 * Polymorphic-input helper, case 2 of 3: is `x` an `Array` (of `tf.Tensor`s)?
 *
 * @param x User-supplied data in any of the accepted container forms.
 * @returns `true` iff `Array.isArray(x)` holds (an empty Array included).
 */
export function isDataArray(x: Tensor|Tensor[]|
                            {[inputName: string]: Tensor}): boolean {
  const isList: boolean = Array.isArray(x);
  return isList;
}
/**
 * Polymorphic-input helper, case 3 of 3: is `x` a "dict" (a plain object
 * mapping input names to `tf.Tensor`s)?
 *
 * Anything that is neither a lone Tensor nor an Array is classified as a dict.
 *
 * @param x User-supplied data in any of the accepted container forms.
 * @returns `true` iff `x` is neither a `Tensor` nor an `Array`.
 */
export function isDataDict(x: Tensor|Tensor[]|
                           {[inputName: string]: Tensor}): boolean {
  if (isDataTensor(x) || isDataArray(x)) {
    return false;
  }
  return true;
}
/**
 * Normalizes inputs or targets provided by users into one Tensor per
 * expected name.
 *
 * `data` may be a single `Tensor`, an `Array` of `Tensor`s (matched to
 * `names` by position), or an object keyed by name.
 *
 * @param data User-provided input data (polymorphic). If `null`/`undefined`,
 *     one `null` is returned per entry of `names`.
 * @param names An Array of expected Tensor names. When empty or `null`, the
 *     model expects no data: an empty Array or an empty object is accepted,
 *     anything else is rejected.
 * @param shapes Optional Array of expected Tensor shapes. A `null` entry
 *     skips the check for that Tensor; a `null` or negative dimension acts as
 *     a wildcard.
 * @param checkBatchAxis Whether to check that the batch axis (axis 0) of the
 *     arrays matches the expected value found in `shapes`.
 * @param exceptionPrefix String prefix used for exception formatting.
 * @returns List of standardized input Tensors (one Tensor per model input),
 *     after passing through `ensureTensorsRank2OrHigher`.
 * @throws ValueError: in case of improperly formatted user data.
 */
export function standardizeInputData(
    data: Tensor|Tensor[]|{[inputName: string]: Tensor}, names: string[],
    shapes?: Shape[], checkBatchAxis = true, exceptionPrefix = ''): Tensor[] {
  if (names == null || names.length === 0) {
    // Check for the case where the model expected no data, but some data got
    // sent. Empty containers (`[]` or `{}`) count as "no data".
    if (data != null) {
      let gotUnexpectedData = false;
      if (isDataArray(data)) {
        // An empty Array must be accepted here rather than falling through
        // to the singleton-Tensor branch below.
        gotUnexpectedData = (data as Tensor[]).length > 0;
      } else if (isDataDict(data)) {
        for (const key in data) {
          if (data.hasOwnProperty(key)) {
            gotUnexpectedData = true;
            break;
          }
        }
      } else {
        // `data` is a singleton Tensor in this case.
        gotUnexpectedData = true;
      }
      if (gotUnexpectedData) {
        throw new ValueError(
            `Error when checking model ${exceptionPrefix} expected no data, ` +
            `but got ${data}`);
      }
    }
    return [];
  }
  if (data == null) {
    return names.map(name => null);
  }

  let arrays: Tensor[];
  if (isDataDict(data)) {
    data = data as {[inputName: string]: Tensor};
    arrays = [];
    for (const name of names) {
      if (data[name] == null) {
        throw new ValueError(
            `No data provided for "${name}". Need data for each key in: ` +
            `${names}`);
      }
      arrays.push(data[name]);
    }
  } else if (isDataArray(data)) {
    data = data as Tensor[];
    if (data.length !== names.length) {
      throw new ValueError(
          `Error when checking model ${exceptionPrefix}: the Array of ` +
          `Tensors that you are passing to your model is not the size the ` +
          `model expected. Expected to see ${names.length} Tensor(s), but ` +
          `instead got the following list of Tensor(s): ${data}`);
    }
    arrays = data;
  } else {
    data = data as Tensor;
    if (names.length > 1) {
      throw new ValueError(
          `The model ${exceptionPrefix} expects ${names.length} Tensor(s), ` +
          `but only received one Tensor. Found: Tensor with shape ${
              data.shape}`);
    }
    arrays = [data];
  }

  arrays = ensureTensorsRank2OrHigher(arrays);

  // Check shape compatibility.
  if (shapes != null) {
    for (let i = 0; i < names.length; ++i) {
      if (shapes[i] == null) {
        continue;
      }
      const array = arrays[i];
      if (array.shape.length !== shapes[i].length) {
        throw new ValueError(
            `Error when checking ${exceptionPrefix}: expected ${names[i]} ` +
            `to have ${shapes[i].length} dimension(s), but got array with ` +
            `shape ${array.shape}`);
      }
      for (let j = 0; j < shapes[i].length; ++j) {
        if (j === 0 && !checkBatchAxis) {
          // Skip the first (batch) axis.
          continue;
        }
        const dim = array.shape[j];
        const refDim = shapes[i][j];
        // A null or negative reference dimension is a wildcard.
        if (refDim != null && refDim >= 0 && dim !== refDim) {
          throw new ValueError(
              `${exceptionPrefix} expected a batch of elements where each ` +
              `example has shape [${shapes[i].slice(1, shapes[i].length)}] ` +
              `(i.e.,tensor shape [*,${
                  shapes[i].slice(1, shapes[i].length)}])` +
              ` but the ${exceptionPrefix} received an input with ${
                  array.shape[0]}` +
              ` examples, each with shape [${
                  array.shape.slice(1, array.shape.length)}]` +
              ` (tensor shape [${array.shape}])`);
        }
      }
    }
  }
  return arrays;
}
/**
 * Validates sample counts of user-provided Tensors.
 *
 * The sample count of a Tensor is the size of its first (batch) axis. All
 * inputs must agree on it, all targets must agree on it, and (when both are
 * non-empty) the input count must equal the target count.
 *
 * @param inputs `Array` of `tf.Tensor`s for inputs.
 * @param targets `Array` of `tf.Tensor`s for targets.
 * @param weights Optional `Array` of `tf.Tensor`s for sample weights.
 *     Currently not inspected.
 * @throws ValueError: in case of incorrectly formatted data.
 */
export function checkArrayLengths(
    inputs: Tensor[], targets: Tensor[], weights?: Tensor[]) {
  const inputCounts = Array.from(new Set(inputs.map(t => t.shape[0])));
  const targetCounts = Array.from(new Set(targets.map(t => t.shape[0])));
  // TODO(cais): Check `weights` as well.
  if (inputCounts.length > 1) {
    throw new ValueError(
        `All input Tensors (x) should have the same number of samples. ` +
        `Got array shapes: ` +
        `${JSON.stringify(inputs.map(input => input.shape))}`);
  }
  if (targetCounts.length > 1) {
    throw new ValueError(
        `All target Tensors (y) should have the same number of samples. ` +
        `Got array shapes: ` +
        `${JSON.stringify(targets.map(target => target.shape))}`);
  }
  // Past the checks above, each list holds at most one distinct count.
  const bothPresent = inputCounts.length > 0 && targetCounts.length > 0;
  if (bothPresent && inputCounts[0] !== targetCounts[0]) {
    throw new ValueError(
        `Input Tensors should have the same number of samples as target ` +
        `Tensors. Found ${inputCounts[0]} input sample(s) and ${
            targetCounts[0]} target ` +
        `sample(s).`);
  }
}
/**
 * Validation on the compatibility of targets and loss functions.
 *
 * This helps prevent users from using loss functions incorrectly:
 * - `categoricalCrossentropy` rejects targets whose last dimension is 1.
 * - `meanSquaredError`, `binaryCrossentropy` and `categoricalCrossentropy`
 *   require every non-batch target dimension to equal the corresponding
 *   output dimension (output dimensions that are `null` are not checked).
 * Entries whose loss function is `null` are skipped.
 *
 * @param targets `Array` of `tf.Tensor`s of targets.
 * @param lossFns `Array` of loss functions, parallel to `targets`.
 * @param outputShapes `Array` of shapes of model outputs, parallel to
 *     `targets`.
 * @throws ValueError if a target is incompatible with its loss function.
 */
function checkLossAndTargetCompatibility(
    targets: Tensor[], lossFns: LossOrMetricFn[], outputShapes: Shape[]) {
  // TODO(cais): Dedicated test coverage?
  const keyLosses = [
    losses.meanSquaredError, losses.binaryCrossentropy,
    losses.categoricalCrossentropy
  ];
  for (let i = 0; i < targets.length; ++i) {
    const y = targets[i];
    const loss = lossFns[i];
    const shape = outputShapes[i];
    if (loss == null) {
      continue;
    }
    if (loss === losses.categoricalCrossentropy) {
      if (y.shape[y.shape.length - 1] === 1) {
        throw new ValueError(
            `You are passing a target array of shape ${y.shape} while using ` +
            `a loss 'categorical_crossentropy'. 'categorical_crossentropy' ` +
            `expects targets to be binary matrices (1s and 0s) of shape ` +
            `[samples, classes].`);
        // TODO(cais): Example code in error message.
      }
    }
    if (keyLosses.indexOf(loss) !== -1) {
      // Compare non-batch dimensions only.
      const slicedYShape = y.shape.slice(1);
      const slicedShape = shape.slice(1);
      for (let j = 0; j < slicedYShape.length; ++j) {
        const targetDim = slicedYShape[j];
        const outDim = slicedShape[j];
        if (outDim != null && targetDim !== outDim) {
          throw new ValueError(
              `A target Tensor with shape ${y.shape} was passed for an ` +
              `output of shape ${shape}, while using a loss function that ` +
              `expects targets to have the same shape as the output.`);
        }
      }
    }
  }
}
/**
 * Check inputs provided by the user.
 *
 * Porting Note: This corresponds to _standardize_input_data() in Python
 * Keras. Because of the strong typing in TF.js, we do not need to convert
 * the data. Specifically:
 *   1) in PyKeras, `data` can be `DataFrame` instances from pandas, for
 *      example. We don't need to worry about that here because there is no
 *      widely popular JavaScript/TypeScript equivalent of pandas (so far).
 *      If one becomes available in the future, we can add support.
 *   2) in PyKeras, inputs can be Python dict. But here we are stipulating
 * that the data is either a single `tf.Tensor` or an Array of `tf.Tensor`s. We
 * may add support for `Object` data inputs in the future when the need
 * arises.
 *
 * Instead, we perform basic checks for number of parameters and shapes. A
 * `null` entry in `shapes` skips that input; a `null` dimension is not checked.
 *
 * @param data: The input data.
 * @param names: Name for the inputs, from the model.
 * @param shapes: Expected shapes for the input data, from the model.
 * @param checkBatchAxis: Whether the size along the batch axis (i.e., the
 *   first dimension) will be checked for matching.
 * @param exceptionPrefix: Exception prefix message, used in generating error
 *   messages.
 * @throws ValueError: on incorrect number of inputs or mismatches in shapes.
 */
function checkInputData(
    data: Tensor|Tensor[], names: string[], shapes?: Shape[],
    checkBatchAxis = true, exceptionPrefix = '') {
  let arrays: Tensor[];
  if (Array.isArray(data)) {
    if (data.length !== names.length) {
      throw new ValueError(
          `Error when checking model ${exceptionPrefix}: the Array of ` +
          `Tensors that you are passing to your model is not the size the ` +
          `model expected. Expected to see ${names.length} Tensor(s),` +
          ` but instead got ${data.length} Tensor(s).`);
    }
    arrays = data;
  } else {
    if (names.length > 1) {
      throw new ValueError(
          `The model expects ${names.length} ${exceptionPrefix} Tensors, ` +
          `but only received one Tensor. Found: array with shape ` +
          `${JSON.stringify(data.shape)}.`);
    }
    arrays = [data];
  }

  if (shapes != null) {
    for (let i = 0; i < names.length; ++i) {
      if (shapes[i] == null) {
        continue;
      }
      const array = arrays[i];
      if (array.shape.length !== shapes[i].length) {
        throw new ValueError(
            `Error when checking ${exceptionPrefix}: expected ${names[i]} ` +
            `to have ${shapes[i].length} dimension(s), but got array with ` +
            `shape ${JSON.stringify(array.shape)}`);
      }
      for (let j = 0; j < shapes[i].length; ++j) {
        if (j === 0 && !checkBatchAxis) {
          // Skip the first (batch) axis.
          continue;
        }
        const dim = array.shape[j];
        const refDim = shapes[i][j];
        if (refDim != null) {
          if (refDim !== dim) {
            throw new ValueError(
                `Error when checking ${exceptionPrefix}: expected ` +
                `${names[i]} to have shape ${JSON.stringify(shapes[i])} but ` +
                `got array with shape ${JSON.stringify(array.shape)}.`);
          }
        }
      }
    }
  }
}
/**
 * Maps metric functions to model outputs.
 *
 * @param metrics A shortcut string name, a metric function, an `Array` of
 *     those, or a dict (`Object`) mapping output names to them. `null`/
 *     `undefined` or an empty `Array` means "no metrics".
 * @param outputNames An `Array` of the names of model outputs.
 * @returns An `Array` (one entry per model output) of `Array` of metric
 *     functions. For instance, if the model has 2 outputs, and for the first
 *     output we want to compute `binaryAccuracy` and `binaryCrossentropy`,
 *     and just `binaryAccuracy` for the second output, the `Array` would look
 *     like:
 *       `[[binaryAccuracy, binaryCrossentropy],  [binaryAccuracy]]`
 *     Each per-output `Array` is a fresh copy (mirroring PyKeras's
 *     `copy.copy`), so entries never alias each other or the caller's input.
 * @throws TypeError: incompatible metrics format.
 */
export function collectMetrics(
    metrics: string|LossOrMetricFn|Array<string|LossOrMetricFn>|
    {[outputName: string]: string | LossOrMetricFn},
    outputNames: string[]): Array<Array<string|LossOrMetricFn>> {
  if (metrics == null || Array.isArray(metrics) && metrics.length === 0) {
    return outputNames.map(name => []);
  }

  let wrappedMetrics: Array<string|LossOrMetricFn>|
      {[outputName: string]: string | LossOrMetricFn};
  if (typeof metrics === 'string' || typeof metrics === 'function') {
    wrappedMetrics = [metrics];
  } else if (Array.isArray(metrics) || typeof metrics === 'object') {
    wrappedMetrics = metrics as Array<string|LossOrMetricFn>|
        {[outputName: string]: string} | {[outputName: string]: LossOrMetricFn};
  } else {
    throw new TypeError(
        'Type of metrics argument not understood. Expected a string, ' +
        `function, Array, or Object, found: ${metrics}`);
  }

  if (Array.isArray(wrappedMetrics)) {
    // We then apply all metrics to all outputs, one independent copy each.
    const allMetrics = wrappedMetrics;
    return outputNames.map(() => allMetrics.slice());
  }

  // In this case, metrics is a dict keyed by output name; outputs without an
  // entry get no metrics.
  const dictMetrics = wrappedMetrics;
  return outputNames.map(name => {
    const outputMetrics =
        dictMetrics.hasOwnProperty(name) ? dictMetrics[name] : [];
    return Array.isArray(outputMetrics) ? outputMetrics.slice() :
                                          [outputMetrics];
  });
}
export interface ModelEvaluateArgs {
  /**
   * Number of samples per evaluation batch (an integer). Defaults to 32 when
   * left unspecified.
   */
  batchSize?: number;

  /**
   * How much progress information to log while evaluating.
   */
  verbose?: ModelLoggingVerbosity;

  /**
   * Per-sample weights scaling each sample's contribution to the loss and
   * metrics. Currently not applied by `LayersModel.evaluate()`.
   */
  sampleWeight?: Tensor;

  /**
   * Integer total number of steps (batches of samples) after which the
   * evaluation round is considered finished. When left at its default of
   * `undefined`, this setting is ignored.
   */
  steps?: number;
}
/**
 * Configuration for calls to `LayersModel.compile()`.
 */
export interface ModelCompileArgs {
  /**
   * An instance of `tf.train.Optimizer` or a string name for an Optimizer.
   */
  optimizer: string|Optimizer;

  /**
   * Objective (loss) function(s) or name(s) of objective function(s).
   * If the model has multiple outputs, you can use a different loss
   * on each output by passing a dictionary (keyed by output name) or an
   * Array of losses (one entry per output).
   * The loss value that will be minimized by the model will then be the sum
   * of all individual losses.
   */
  loss: string|string[]|{[outputName: string]: string}|LossOrMetricFn|
      LossOrMetricFn[]|{[outputName: string]: LossOrMetricFn};

  /**
   * List of metrics to be evaluated by the model during training and testing.
   * Typically you will use `metrics: ['accuracy']`.
   * To specify different metrics for different outputs of a multi-output
   * model, you could also pass a dictionary keyed by output name.
   */
  metrics?: string|LossOrMetricFn|Array<string|LossOrMetricFn>|
      {[outputName: string]: string | LossOrMetricFn};

  // TODO(cais): Add lossWeights, sampleWeightMode, weightedMetrics, and
  //   targetTensors.
}
// Format name string for `LayersModel`.
// NOTE(review): it is not referenced within this chunk — presumably it is
// written into saved model artifacts to identify their format; confirm.
const LAYERS_MODEL_FORMAT_NAME = 'layers-model';
/**
* A `tf.LayersModel` is a directed, acyclic graph of `tf.Layer`s plus methods
* for training, evaluation, prediction and saving.
*
* `tf.LayersModel` is the basic unit of training, inference and evaluation in
* TensorFlow.js. To create a `tf.LayersModel`, use `tf.LayersModel`.
*
* See also:
* `tf.Sequential`, `tf.loadLayersModel`.
*
* @doc {heading: 'Models', subheading: 'Classes'}
*/
export class LayersModel extends Container implements tfc.InferenceModel {
// The class name is 'Model' rather than 'LayersModel' for backwards
// compatibility since this class name shows up in the serialization format.
/** @nocollapse */
static className = 'Model';
protected optimizer_: Optimizer;
// Whether the model instance owns the optimizer: `true` if and only if
// `optimizer` is created from a string parameter during `compile()` call.
protected isOptimizerOwned: boolean;
loss: string|string[]|{[outputName: string]: string}|LossOrMetricFn|
LossOrMetricFn[]|{[outputName: string]: LossOrMetricFn};
lossFunctions: LossOrMetricFn[];
// TODO(cais): These private variables should probably not have the string
// 'feed' in their names, because we are not dealing with a symbolic
// backend.
private feedOutputShapes: Shape[];
private feedLossFns: LossOrMetricFn[];
private collectedTrainableWeights: LayerVariable[];
private testFunction: (data: Tensor[]) => Scalar[];
history: History;
// A public property that can be set by Callbacks to order early stopping
// during `fit()` calls.
protected stopTraining_: boolean;
protected isTraining: boolean;
metrics: string|LossOrMetricFn|Array<string|LossOrMetricFn>|
{[outputName: string]: string | LossOrMetricFn};
metricsNames: string[];
// Porting Note: `metrics_tensors` in PyKeras is a symbolic tensor. But given
// the imperative nature of tfjs-core, `metricsTensors` is a
// TypeScript function here.
// Also note that due to the imperative nature of tfjs-core, `metricsTensor`
// here needs an output index to keep track of which output of the
// LayersModel a metric belongs to. This is unlike `metrics_tensors` in
// PyKeras, which is a `list` of symbolic tensors, each of which has
// implicit "knowledge" of the outputs it depends on.
metricsTensors: Array<[LossOrMetricFn, number]>;
// User defind metadata (if any).
private userDefinedMetadata: {};
/**
 * Creates a `LayersModel` from a `ContainerArgs` configuration.
 *
 * @param args Configuration forwarded unchanged to the `Container`
 *     constructor.
 */
constructor(args: ContainerArgs) {
  super(args);
  // The `isTraining` flag starts out cleared on a new instance.
  this.isTraining = false;
}
/**
 * Prints a text summary of the model's layers.
 *
 * The printed table lists:
 * - the name and type of every layer in the model;
 * - each layer's output shape(s);
 * - how many weight parameters each layer has;
 * - for models whose topology is not sequential-like, the inputs each layer
 *   receives;
 * - the model's total counts of trainable and non-trainable parameters.
 *
 * ```js
 * const input1 = tf.input({shape: [10]});
 * const input2 = tf.input({shape: [20]});
 * const dense1 = tf.layers.dense({units: 4}).apply(input1);
 * const dense2 = tf.layers.dense({units: 8}).apply(input2);
 * const concat = tf.layers.concatenate().apply([dense1, dense2]);
 * const output =
 *     tf.layers.dense({units: 3, activation: 'softmax'}).apply(concat);
 *
 * const model = tf.model({inputs: [input1, input2], outputs: output});
 * model.summary();
 * ```
 *
 * @param lineLength Custom line length, in number of characters.
 * @param positions Custom widths of each of the columns, given either as
 *     fractions of `lineLength` (e.g., `[0.5, 0.75, 1]`) or as absolute
 *     numbers of characters (e.g., `[30, 50, 65]`). Each number marks the
 *     right-most (i.e., ending) position of a column.
 * @param printFn Custom print function replacing the default `console.log`.
 *     For example, pass `x => {}` to silence the output.
 * @throws ValueError if the model has not been built yet.
 *
 * @doc {heading: 'Models', subheading: 'Classes'}
 */
summary(
    lineLength?: number, positions?: number[],
    printFn:
        // tslint:disable-next-line:no-any
    (message?: any, ...optionalParams: any[]) => void = console.log) {
  if (this.built) {
    printSummary(this, lineLength, positions, printFn);
    return;
  }
  throw new ValueError(
      `This model has never been called, thus its weights have not been ` +
      `created yet. So no summary can be displayed. Build the model ` +
      `first (e.g., by calling it on some test data).`);
}
/**
 * Configures and prepares the model for training and evaluation. Compiling
 * outfits the model with an optimizer, loss, and/or metrics. Calling `fit`
 * or `evaluate` on an un-compiled model will throw an error.
 *
 * A `null`/`undefined` `args.loss` is treated as an empty Array of losses.
 * The caller's `args` object itself is never modified.
 *
 * @param args a `ModelCompileArgs` specifying the loss, optimizer, and
 *     metrics to be used for fitting and evaluating this model.
 * @throws ValueError if `args.optimizer` is neither a string nor a
 *     `tf.Optimizer` instance, if a loss dictionary contains a key that is
 *     not an output name, or if a loss Array does not have exactly one entry
 *     per model output.
 *
 * @doc {heading: 'Models', subheading: 'Classes'}
 */
compile(args: ModelCompileArgs): void {
  // Resolve a missing loss locally instead of writing it back into `args`.
  const loss: ModelCompileArgs['loss'] = args.loss == null ? [] : args.loss;
  this.loss = loss;

  if (typeof args.optimizer === 'string') {
    this.optimizer_ = optimizers.getOptimizer(args.optimizer);
    this.isOptimizerOwned = true;
  } else {
    if (!(args.optimizer instanceof Optimizer)) {
      throw new ValueError(
          `User-defined optimizer must be an instance of tf.Optimizer.`);
    }
    this.optimizer_ = args.optimizer;
    this.isOptimizerOwned = false;
  }

  // TODO(cais): Add lossWeights.
  // TODO(cais): Add sampleWeightMode.

  // Prepare loss functions: one entry per model output.
  let lossFunctions: LossOrMetricFn[] = [];
  if (!Array.isArray(loss) && typeof loss !== 'string' &&
      typeof loss !== 'function') {
    // Dictionary form, keyed by output name.
    const lossDict = loss as {[outputName: string]: string};
    for (const name in lossDict) {
      if (this.outputNames.indexOf(name) === -1) {
        throw new ValueError(
            `Unknown entry in loss dictionary: "${name}". ` +
            `Only expected the following keys: ${this.outputNames}`);
      }
    }
    for (const name of this.outputNames) {
      if (lossDict[name] == null) {
        console.warn(
            `Output "${name}" is missing from loss dictionary. We assume ` +
            `this was done on purpose, and we will not be expecting data ` +
            `to be passed to ${name} during training`);
      }
      lossFunctions.push(losses.get(lossDict[name]));
    }
  } else if (Array.isArray(loss)) {
    if (loss.length !== this.outputs.length) {
      throw new ValueError(
          `When passing an Array as loss, it should have one entry per ` +
          `model output. The model has ${this.outputs.length} output(s), ` +
          `but you passed loss=${loss}.`);
    }
    const theLosses = loss as Array<string|LossOrMetricFn>;
    lossFunctions = theLosses.map(l => losses.get(l));
  } else {
    // A single loss shared by every output.
    const lossFunction = losses.get(loss);
    this.outputs.forEach(_ => {
      lossFunctions.push(lossFunction);
    });
  }

  this.lossFunctions = lossFunctions;

  this.feedOutputNames = [];
  this.feedOutputShapes = [];
  this.feedLossFns = [];
  for (let i = 0; i < this.outputs.length; ++i) {
    // TODO(cais): Logic for skipping target(s).
    const shape = this.internalOutputShapes[i];
    const name = this.outputNames[i];
    this.feedOutputNames.push(name);
    this.feedOutputShapes.push(shape);
    this.feedLossFns.push(this.lossFunctions[i]);
  }

  // TODO(cais): Add logic for output masks.
  // TODO(cais): Add logic for sample weights.
  const skipTargetIndices: number[] = [];

  // Prepare metrics.
  this.metrics = args.metrics;
  // TODO(cais): Add weightedMetrics.
  this.metricsNames = ['loss'];
  this.metricsTensors = [];

  // Compute total loss.
  // Porting Note: In PyKeras, metrics_tensors are symbolic tensor objects.
  //   Here, metricsTensors are TypeScript functions. This difference is due
  //   to the difference in symbolic/imperative property of the backends.
  nameScope('loss', () => {
    for (let i = 0; i < this.outputs.length; ++i) {
      if (skipTargetIndices.indexOf(i) !== -1) {
        continue;
      }
      // TODO(cais): Add weightedLoss, sampleWeight and mask.
      //   The following line should be weightedLoss
      const weightedLoss = this.lossFunctions[i];
      // Per-output losses are only tracked separately for multi-output
      // models.
      if (this.outputs.length > 1) {
        this.metricsTensors.push([weightedLoss, i]);
        this.metricsNames.push(this.outputNames[i] + '_loss');
      }
    }

    // Porting Note: Due to the imperative nature of the backend, we calculate
    //   the regularizer penalties in the totalLossFunction, instead of here.
  });

  const nestedMetrics = collectMetrics(args.metrics, this.outputNames);
  // TODO(cais): Add nestedWeightedMetrics.

  /**
   * Records a metric for one output; for multi-output models the metric
   * name is prefixed with the output name.
   */
  const appendMetric =
      (outputIndex: number, metricName: string,
       metricTensor: LossOrMetricFn) => {
        if (this.outputNames.length > 1) {
          metricName = this.outputNames[outputIndex] + '_' + metricName;
        }
        this.metricsNames.push(metricName);
        this.metricsTensors.push([metricTensor, outputIndex]);
      };

  nameScope('metric', () => {
    for (let i = 0; i < this.outputs.length; ++i) {
      if (skipTargetIndices.indexOf(i) !== -1) {
        continue;
      }
      const outputMetrics = nestedMetrics[i];
      // TODO(cais): Add weights and outputWeightedMetrics.

      // TODO(cais): Add optional arg `weights` to the following function.
      const handleMetrics = (metrics: Array<string|LossOrMetricFn>) => {
        const metricNamePrefix = '';
        let metricName: string;
        let accFn: LossOrMetricFn;
        let weightedMetricFn: LossOrMetricFn;
        // TODO(cais): Use 'weights_' for weighted metrics.

        for (const metric of metrics) {
          if (typeof metric === 'string' &&
              ['accuracy', 'acc', 'crossentropy', 'ce'].indexOf(metric) !==
                  -1) {
            // Shortcut names are resolved based on the output shape and the
            // loss function chosen for this output.
            const outputShape = this.internalOutputShapes[i];

            if (outputShape[outputShape.length - 1] === 1 ||
                this.lossFunctions[i] === losses.binaryCrossentropy) {
              // case: binary accuracy/crossentropy.
              if (['accuracy', 'acc'].indexOf(metric) !== -1) {
                accFn = Metrics.binaryAccuracy;
              } else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
                accFn = Metrics.binaryCrossentropy;
              }
            } else if (
                this.lossFunctions[i] ===
                losses.sparseCategoricalCrossentropy) {
              // case: categorical accuracy / crossentropy with sparse
              // targets.
              if (['accuracy', 'acc'].indexOf(metric) !== -1) {
                accFn = Metrics.sparseCategoricalAccuracy;
              } else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
                accFn = Metrics.sparseCategoricalCrossentropy;
              }
            } else {
              // case: categorical accuracy / crossentropy.
              if (['accuracy', 'acc'].indexOf(metric) !== -1) {
                accFn = Metrics.categoricalAccuracy;
              } else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
                accFn = Metrics.categoricalCrossentropy;
              }
            }
            let suffix: string;
            if (['accuracy', 'acc'].indexOf(metric) !== -1) {
              suffix = 'acc';
            } else if (['crossentropy', 'ce'].indexOf(metric) !== -1) {
              suffix = 'ce';
            }
            // TODO(cais): Add weighting actually.
            weightedMetricFn = accFn;
            metricName = metricNamePrefix + suffix;
          } else {
            const metricFn = Metrics.get(metric);
            // TODO(cais): Add weighting actually.
            weightedMetricFn = metricFn;
            metricName =
                metricNamePrefix + Metrics.getLossOrMetricName(metric);
          }

          // TODO(cais): Add weighting and masking to metricResult.
          let metricResult: LossOrMetricFn;
          nameScope(metricName, () => {
            metricResult = weightedMetricFn;
          });
          appendMetric(i, metricName, metricResult);
        }
      };

      handleMetrics(outputMetrics);
      // TODO(cais): Call handleMetrics with weights.
    }
  });

  // Porting Notes: Given the imperative backend of tfjs-core,
  //   there is no need for constructing the symbolic graph and placeholders.
  this.collectedTrainableWeights = this.trainableWeights;
}
/**
 * Warns when the trainable weights have drifted from the snapshot taken at
 * compile time.
 *
 * `compile()` records `this.trainableWeights` into
 * `this.collectedTrainableWeights`. If the two lists now differ in length, a
 * warning is logged. This typically happens when `model.trainable` was
 * changed without calling `model.compile()` again. No check is made while
 * no snapshot exists.
 */
protected checkTrainableWeightsConsistency(): void {
  const collected = this.collectedTrainableWeights;
  if (collected != null &&
      this.trainableWeights.length !== collected.length) {
    console.warn(
        'Discrepancy between trainableweights and collected trainable ' +
        'weights. Did you set `model.trainable` without calling ' +
        '`model.compile()` afterwards?');
  }
}
/**
 * Computes the loss value and metric values of the model in test mode.
 *
 * The loss and metrics are the ones configured by `compile()`, which must be
 * called before `evaluate()`. The computation is done in batches.
 *
 * ```js
 * const model = tf.sequential({
 *   layers: [tf.layers.dense({units: 1, inputShape: [10]})]
 * });
 * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
 * const result = model.evaluate(
 *     tf.ones([8, 10]), tf.ones([8, 1]), {batchSize: 4});
 * result.print();
 * ```
 *
 * @param x `tf.Tensor` of test data, or an `Array` of `tf.Tensor`s if the
 *     model has multiple inputs.
 * @param y `tf.Tensor` of target data, or an `Array` of `tf.Tensor`s if the
 *     model has multiple outputs.
 * @param args A `ModelEvaluateArgs`, containing optional fields. `batchSize`
 *     defaults to 32.
 *
 * @return `Scalar` test loss (if the model has a single output and no
 *     metrics) or `Array` of `Scalar`s (if the model has multiple outputs
 *     and/or metrics). The attribute `model.metricsNames` gives the display
 *     labels for the scalar outputs.
 *
 * @doc {heading: 'Models', subheading: 'Classes'}
 */
evaluate(
    x: Tensor|Tensor[], y: Tensor|Tensor[],
    args: ModelEvaluateArgs = {}): Scalar|Scalar[] {
  const batchSize = args.batchSize == null ? 32 : args.batchSize;
  checkBatchSize(batchSize);

  // TODO(cais): Standardize `config.sampleWeights` as well.
  // Validate the user data, including the batch axis (third argument).
  const [inputs, targets] =
      this.standardizeUserDataXY(x, y, true, batchSize);
  try {
    // TODO(cais): If uses `useLearningPhase`, set the corresponding element
    // of the input to 0.
    this.makeTestFunction();
    const results = this.testLoop(
        this.testFunction, inputs.concat(targets), batchSize, args.verbose,
        args.steps);
    return singletonOrArray(results);
  } finally {
    // Runs even if the test loop throws. NOTE(review): `disposeNewTensors`
    // presumably frees only tensors created during standardization, sparing
    // the caller's `x`/`y` — confirm.
    disposeNewTensors(inputs, x);
    disposeNewTensors(targets, y);
  }
}
// TODO(cais): Add code snippet below once real dataset objects are
// available.
/**
 * Evaluate model using a dataset object.
 *
 * Note: Unlike `evaluate()`, this method is asynchronous (`async`).
 *
 * @param dataset A dataset object. Its `iterator()` method is expected
 *   to generate a dataset iterator object, the `next()` method of which
 *   is expected to produce data batches for evaluation. The return value
 *   of the `next()` call ought to contain a boolean `done` field and a
 *   `value` field. The `value` field is expected to be an array of two
 *   `tf.Tensor`s or an array of two nested `tf.Tensor` structures. The former
 *   case is for models with exactly one input and one output (e.g.
 *   a sequential model). The latter case is for models with multiple
 *   inputs and/or multiple outputs. Of the two items in the array, the
 *   first is the input feature(s) and the second is the output target(s).
 * @param args A configuration object for the dataset-based evaluation.
 * @returns Loss and metric values: a single `Scalar` or an Array of `Scalar`
 *   objects.
 *
 * @doc {heading: 'Models', subheading: 'Classes'}
 */
async evaluateDataset(dataset: Dataset<{}>, args?: ModelEvaluateDatasetArgs):
    Promise<Scalar|Scalar[]> {
  // Build the test function first, then delegate to the module-level
  // `evaluateDataset` helper with this model as the first argument.
  this.makeTestFunction();
  const evaluation = await evaluateDataset(this, dataset, args);
  return evaluation;
}
/**
 * Get number of samples provided for training, evaluation or prediction.
 *
 * @param ins Input `tf.Tensor`, or an Array of `tf.Tensor`s. When `steps` is
 *   not given, the sample count is read from the first dimension of `ins`
 *   (or of its first element, for an Array).
 * @param batchSize Integer batch size, optional. Must be null/undefined when
 *   `steps` is specified.
 * @param steps Total number of steps (batches of samples) before
 *   declaring loop finished. Optional.
 * @param stepsName The public API's parameter name for `steps`, used in error
 *   messages.
 * @returns Number of samples provided, or `null` when `steps` is specified.
 * @throws ValueError If both `steps` and `batchSize` are specified, or if
 *   `steps` is absent and `ins` provides no tensor to read a shape from
 *   (null/undefined, an empty Array, or an Array whose first item is null).
 */
private checkNumSamples(
    ins: Tensor|Tensor[], batchSize?: number, steps?: number,
    stepsName = 'steps'): number {
  if (steps != null) {
    if (batchSize != null) {
      throw new ValueError(
          `If ${stepsName} is set, batchSize must be null or undefined. ` +
          `Got batchSize = ${batchSize}`);
    }
    // An explicit step count makes the sample count unnecessary.
    return null;
  }
  const firstInput = Array.isArray(ins) ? ins[0] : ins;
  // Guard against null/undefined input as well as an empty Array, which
  // would otherwise fail with an opaque TypeError on `.shape`.
  if (firstInput == null) {
    throw new ValueError(
        `Either the input data should have a defined shape, or ` +
        `${stepsName} should be specified.`);
  }
  return firstInput.shape[0];
}
/**
 * Execute internal tensors of the model with input data feed.
 *
 * @param inputs Input data feed: a single `tf.Tensor`, an Array of
 *   `tf.Tensor`s (one per model input, in order), or a map from model input
 *   names to `tf.Tensor`s. Must match the inputs of the model.
 * @param outputs Name, or non-empty Array of names, of the output tensors to
 *   be fetched. Must match names of the SymbolicTensors that belong to the
 *   graph.
 * @returns Fetched values for `outputs`: an Array when `outputs` is an
 *   Array, otherwise a single `tf.Tensor`.
 * @throws ValueError If `outputs` is an empty Array, if the number of Array
 *   inputs differs from the model's input count, or if a named input is
 *   missing from the map.
 */
execute(inputs: Tensor|Tensor[]|NamedTensorMap, outputs: string|string[]):
    Tensor|Tensor[] {
  const wantsArray = Array.isArray(outputs);
  const outputNames: string[] = Array.isArray(outputs) ? outputs : [outputs];
  // A bare string always yields one name, so an empty list can only come
  // from an empty `outputs` Array.
  if (outputNames.length === 0) {
    throw new ValueError(
        '`outputs` is an empty Array, which is not allowed.');
  }
  const fetches = this.retrieveSymbolicTensors(outputNames);

  // Build the FeedDict; a bare Tensor is treated as a one-element list.
  const feedDict = new FeedDict();
  const feedValues = inputs instanceof Tensor ? [inputs] : inputs;
  if (!Array.isArray(feedValues)) {
    // Named feed: every model input must be present in the map.
    for (const symbolicInput of this.inputs) {
      const value = feedValues[symbolicInput.name];
      if (value == null) {
        throw new ValueError(
            `No value is provided for the model's input ${symbolicInput.name}`);
      }
      feedDict.add(symbolicInput, value);
    }
  } else {
    // Positional feed: values pair up with `this.inputs` by index.
    const expectedCount = this.inputs.length;
    if (feedValues.length !== expectedCount) {
      throw new ValueError(
          `The number of inputs provided (${feedValues.length}) ` +
          `does not match the number of inputs of this model ` +
          `(${expectedCount}).`);
    }
    for (let k = 0; k < expectedCount; k++) {
      feedDict.add(this.inputs[k], feedValues[k]);
    }
  }

  // Run execution.
  const fetched = execute(fetches, feedDict) as Tensor[];
  if (wantsArray) {
    return fetched;
  }
  return fetched[0];
}
/**
* Retrieve the model's internal symbolic tensors from symbolic-tensor names.
*/
private retrieveSymbolicTensors(symbolicTensorNames: string[]):
SymbolicTensor[] {
const outputSymbolicTensors: SymbolicTensor[] =
pyListRepeat(null, symbolicTensorNames.length);
let outputsRemaining = symbolicTensorNames.length;
for (const layer of this.layers) {
const layerOutputs: SymbolicTensor[] =
Array.isArray(layer.output) ? layer.output : [layer.output];
const layerOutputNames = layerOutputs.map(output => output.name);
for (let i = 0; i < symbolicTensorNames.length; ++i) {
const index = layerOutputNames.indexOf(symbolicTensorNames[i]);
if (index !== -1) {
outputSymbolicTensors[i] = layerOutputs[index];
outputsRemaining--;
}
if (outputsRemaining === 0) {
break;
}
}