Skip to content

Commit

Permalink
Add kernel timing information to tf.profile() (tensorflow#3721)
Browse files Browse the repository at this point in the history
FEATURE
  • Loading branch information
Linchenn authored Aug 6, 2020
1 parent 2c59dfb commit 3d14962
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 30 deletions.
18 changes: 14 additions & 4 deletions tfjs-core/src/engine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ type KernelInfo = {
totalTensorsSnapshot: number;
inputShapes: number[][];
outputShapes: number[][];
kernelTimeMs: number | {error: string} | Promise<number|{error: string}>;
extraInfo: string | Promise<string>;
};

export type ProfileInfo = {
Expand Down Expand Up @@ -625,15 +627,17 @@ export class Engine implements TensorTracker, DataMover {
}

// Stop recording to a tape when running a kernel.
let kernelProfile: KernelProfile;
this.scopedRun(
() => this.state.kernelDepth++, () => this.state.kernelDepth--, () => {
if (!this.ENV.getBool('DEBUG')) {
if (!this.ENV.getBool('DEBUG') && !this.state.profiling) {
outputs = kernelFunc();
} else {
let kernelProfile: KernelProfile;
kernelProfile = this.profiler.profileKernel(
kernelName, inputs, () => kernelFunc());
this.profiler.logKernelProfile(kernelProfile);
if (this.ENV.getBool('DEBUG')) {
this.profiler.logKernelProfile(kernelProfile);
}
outputs = kernelProfile.outputs;
}
});
Expand All @@ -652,7 +656,9 @@ export class Engine implements TensorTracker, DataMover {
totalTensorsSnapshot: this.state.numTensors,
inputShapes: Object.keys(inputs).map(
key => inputs[key] != null ? inputs[key].shape : null),
outputShapes: outputs.map(item => item.shape)
outputShapes: outputs.map(item => item.shape),
kernelTimeMs: kernelProfile.timeMs,
extraInfo: kernelProfile.extraInfo
});
}
return (Array.isArray(out) ? outputs : outputs[0]) as T;
Expand Down Expand Up @@ -878,6 +884,10 @@ export class Engine implements TensorTracker, DataMover {
this.state.activeProfile.newBytes = this.state.numBytes - startBytes;
this.state.activeProfile.newTensors =
this.state.numTensors - startNumTensors;
for (const kernel of this.state.activeProfile.kernels) {
kernel.kernelTimeMs = await kernel.kernelTimeMs;
kernel.extraInfo = await kernel.extraInfo;
}
return this.state.activeProfile;
}

Expand Down
76 changes: 50 additions & 26 deletions tfjs-core/src/engine_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -387,26 +387,40 @@ describeWithFlags('profile', ALL_ENVS, () => {
expect(profile.peakBytes).toBe(24);
expect(profile.newTensors).toBe(1);
expectArraysClose(await result.data(), [1, 2, 3]);
expect(profile.kernels).toEqual([
{
'name': 'Square',
'bytesAdded': 12,
'totalBytesSnapshot': 24,
'tensorsAdded': 1,
'totalTensorsSnapshot': 2,
'inputShapes': [[3]],
'outputShapes': [[3]]
},
{
'name': 'Square',
'bytesAdded': 12,
'totalBytesSnapshot': 24,
'tensorsAdded': 1,
'totalTensorsSnapshot': 2,
'inputShapes': [[3]],
'outputShapes': [[3]]
}
]);
expect(profile.kernels.length).toBe(2);

// Test the types for `kernelTimeMs` and `extraInfo` to confirm the promises
// are resolved.
expect(typeof profile.kernels[0].kernelTimeMs).toBe('number');
expect(typeof profile.kernels[0].extraInfo).toBe('string');
expect(typeof profile.kernels[1].kernelTimeMs).toBe('number');
expect(typeof profile.kernels[1].extraInfo).toBe('string');

// The specific values of `kernelTimeMs` and `extraInfo` are tested in the
// tests of Profiler.profileKernel, so their values are not tested here.
expect(profile.kernels[0]).toEqual({
'name': 'Square',
'bytesAdded': 12,
'totalBytesSnapshot': 24,
'tensorsAdded': 1,
'totalTensorsSnapshot': 2,
'inputShapes': [[3]],
'outputShapes': [[3]],
'kernelTimeMs': profile.kernels[0].kernelTimeMs,
'extraInfo': profile.kernels[0].extraInfo
});

expect(profile.kernels[1]).toEqual({
'name': 'Square',
'bytesAdded': 12,
'totalBytesSnapshot': 24,
'tensorsAdded': 1,
'totalTensorsSnapshot': 2,
'inputShapes': [[3]],
'outputShapes': [[3]],
'kernelTimeMs': profile.kernels[1].kernelTimeMs,
'extraInfo': profile.kernels[1].extraInfo
});
});

it('squaring without disposing', async () => {
Expand All @@ -422,15 +436,20 @@ describeWithFlags('profile', ALL_ENVS, () => {
expect(profile.peakBytes).toBe(24);
expect(profile.newTensors).toBe(2);
expectArraysClose(await result.data(), [1, 4, 9]);
expect(profile.kernels).toEqual([{
expect(profile.kernels.length).toBe(1);
expect(typeof profile.kernels[0].kernelTimeMs).toBe('number');
expect(typeof profile.kernels[0].extraInfo).toBe('string');
expect(profile.kernels[0]).toEqual({
'name': 'Square',
'bytesAdded': 12,
'totalBytesSnapshot': 24,
'tensorsAdded': 1,
'totalTensorsSnapshot': 2,
'inputShapes': [[3]],
'outputShapes': [[3]]
}]);
'outputShapes': [[3]],
'kernelTimeMs': profile.kernels[0].kernelTimeMs,
'extraInfo': profile.kernels[0].extraInfo
});
});

it('squaring in async query', async () => {
Expand All @@ -448,15 +467,20 @@ describeWithFlags('profile', ALL_ENVS, () => {
expect(profile.peakBytes).toBe(24);
expect(profile.newTensors).toBe(1);
expectArraysClose(await result.data(), [1, 2, 3]);
expect(profile.kernels).toEqual([{
expect(profile.kernels.length).toBe(1);
expect(typeof profile.kernels[0].kernelTimeMs).toBe('number');
expect(typeof profile.kernels[0].extraInfo).toBe('string');
expect(profile.kernels[0]).toEqual({
'name': 'Square',
'bytesAdded': 12,
'totalBytesSnapshot': 24,
'tensorsAdded': 1,
'totalTensorsSnapshot': 2,
'inputShapes': [[3]],
'outputShapes': [[3]]
}]);
'outputShapes': [[3]],
'kernelTimeMs': profile.kernels[0].kernelTimeMs,
'extraInfo': profile.kernels[0].extraInfo
});
});
});

Expand Down

0 comments on commit 3d14962

Please sign in to comment.