src/backends/backend.ts (4 additions & 4 deletions)

@@ -16,7 +16,7 @@
  */
 
 import {Conv2DInfo, Conv3DInfo} from '../ops/conv_util';
-import {Activation} from '../ops/fused_util';
+import {Activation, FusedBatchMatMulConfig} from '../ops/fused_util';
 import {Backend, DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../tensor';
 import {BackendValues, DataType, PixelData, Rank, ShapeMap} from '../types';
 
@@ -132,8 +132,8 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer {
   }
 
   fusedBatchMatMul(
-      a: Tensor3D, b: Tensor3D, transposeA: boolean, transposeB: boolean,
-      bias?: Tensor, activation?: Activation): Tensor3D {
+      {a, b, transposeA, transposeB, bias, activation, preluActivationWeights}:
+          FusedBatchMatMulConfig): Tensor3D {
     throw new Error('Not yet implemented');
   }
 
@@ -413,7 +413,7 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer {
 
   fusedConv2d(
       x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
-      activation?: Activation): Tensor4D {
+      activation?: Activation, preluActivationWeights?: Tensor): Tensor4D {
     throw new Error('Not yet implemented');
   }
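Note on the signature change: fusedBatchMatMul now takes a single config object instead of six positional arguments, which lets the optional preluActivationWeights tensor ride along without growing the parameter list further. FusedBatchMatMulConfig is declared in src/ops/fused_util.ts, which is not part of this diff; inferred from the destructuring pattern above, it would look roughly like:

    // Assumed shape of the config type, reconstructed from the destructuring
    // pattern in fusedBatchMatMul; fused_util.ts itself is not in this diff.
    export type FusedBatchMatMulConfig = {
      a: Tensor3D,
      b: Tensor3D,
      transposeA: boolean,
      transposeB: boolean,
      bias?: Tensor,
      activation?: Activation,
      preluActivationWeights?: Tensor
    };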
src/backends/cpu/backend_cpu.ts (14 additions & 7 deletions)

@@ -26,7 +26,7 @@ import * as broadcast_util from '../../ops/broadcast_util';
 import * as concat_util from '../../ops/concat_util';
 import {Conv2DInfo, Conv3DInfo} from '../../ops/conv_util';
 import * as erf_util from '../../ops/erf_util';
-import {Activation} from '../../ops/fused_util';
+import {Activation, FusedBatchMatMulConfig} from '../../ops/fused_util';
 import * as gather_nd_util from '../../ops/gather_nd_util';
 import * as ops from '../../ops/ops';
 import {buffer, scalar, tensor, tensor3d, tensor4d} from '../../ops/ops';
@@ -47,11 +47,14 @@ import {topkImpl} from '../topk_impl';
 import {whereImpl} from '../where_impl';
 
 function mapActivation(
-    backend: MathBackendCPU, activation: Activation, x: Tensor): Tensor {
+    backend: MathBackendCPU, x: Tensor, activation: Activation,
+    preluActivationWeights?: Tensor): Tensor {
   if (activation === 'linear') {
     return backend.linear(x);
   } else if (activation === 'relu') {
     return backend.relu(x);
+  } else if (activation === 'prelu') {
+    return backend.prelu(x, preluActivationWeights);
   }
   throw new Error(
       `Activation ${activation} has not been implemented for the CPU backend.`);
@@ -522,14 +525,16 @@ export class MathBackendCPU implements KernelBackend {
   }
 
   fusedBatchMatMul(
-      a: Tensor3D, b: Tensor3D, transposeA: boolean, transposeB: boolean,
-      bias?: Tensor, activation?: Activation): Tensor3D {
+      {a, b, transposeA, transposeB, bias, activation, preluActivationWeights}:
+          FusedBatchMatMulConfig): Tensor3D {
     let result = this.batchMatMul(a, b, transposeA, transposeB);
     if (bias) {
       result = this.add(result, bias) as Tensor3D;
     }
     if (activation) {
-      result = mapActivation(this, activation, result) as Tensor3D;
+      result =
+          mapActivation(this, result, activation, preluActivationWeights) as
+          Tensor3D;
     }
     return result;
   }
@@ -1515,14 +1520,16 @@ export class MathBackendCPU implements KernelBackend {
 
   fusedConv2d(
       x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
-      activation?: Activation): Tensor4D {
+      activation?: Activation, preluActivationWeights?: Tensor): Tensor4D {
     let result = this.conv2d(x, filter, convInfo);
 
     if (bias) {
       result = this.add(result, bias) as Tensor4D;
     }
     if (activation) {
-      result = mapActivation(this, activation, result) as Tensor4D;
+      result =
+          mapActivation(this, result, activation, preluActivationWeights) as
+          Tensor4D;
     }
     return result;
   }
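The new 'prelu' branch defers to backend.prelu(x, weights), the parametric ReLU: positive values pass through unchanged, negative values are scaled by a learned slope that broadcasts against x. A minimal reference of the element-wise rule the fused kernels are expected to match (a sketch, not the backend implementation — the real op broadcasts arbitrary weight shapes):

    // prelu(x, alpha) = x            for x >= 0
    //                 = alpha * x    for x <  0
    function preluRef(x: Float32Array, alpha: Float32Array): Float32Array {
      const out = new Float32Array(x.length);
      for (let i = 0; i < x.length; i++) {
        const slope = alpha[i % alpha.length];  // crude stand-in for broadcasting
        out[i] = x[i] < 0 ? slope * x[i] : x[i];
      }
      return out;
    }

    // preluRef(Float32Array.of(-2, 3), Float32Array.of(0.5)) -> [-1, 3]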
src/backends/webgl/backend_webgl.ts (49 additions & 16 deletions)

@@ -28,7 +28,7 @@ import * as array_ops_util from '../../ops/array_ops_util';
 import * as axis_util from '../../ops/axis_util';
 import {computeOutShape} from '../../ops/concat_util';
 import {Conv2DInfo, Conv3DInfo} from '../../ops/conv_util';
-import {Activation} from '../../ops/fused_util';
+import {Activation, FusedBatchMatMulConfig} from '../../ops/fused_util';
 import * as gather_nd_util from '../../ops/gather_nd_util';
 import * as reduce_util from '../../ops/reduce_util';
 import * as scatter_nd_util from '../../ops/scatter_nd_util';
@@ -174,6 +174,11 @@ function mapActivationToShaderProgram(
       return unary_packed_op.RELU;
     }
     return unary_op.RELU;
+  } else if (activation === 'prelu') {
+    if (packed) {
+      return binaryop_packed_gpu.PRELU;
+    }
+    return binaryop_gpu.PRELU;
   }
   throw new Error(`Activation ${
       activation} has not been implemented for the WebGL backend.`);
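These returns are GLSL source fragments, not compiled programs: each string is spliced into the body of a generated activation() function. prelu maps to a binary-op snippet because it reads a second tensor (the slopes). The two PRELU constants live in binaryop_gpu.ts and binaryop_packed_gpu.ts, outside this diff; they plausibly look like:

    // Assumed contents, for orientation only — not part of this diff.
    // binaryop_gpu.PRELU: scalar variant; a is the value, b the slope.
    export const PRELU = `return (a < 0.) ? b * a : a;`;
    // binaryop_packed_gpu.PRELU: vec4 variant for packed textures, applying
    // the same rule lane-wise.
    export const PRELU_PACKED = `
      vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));
      return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);
    `;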
@@ -865,26 +870,30 @@
   }
 
   fusedBatchMatMul(
-      a: Tensor3D, b: Tensor3D, transposeA: boolean, transposeB: boolean,
-      bias?: Tensor, activation?: Activation): Tensor3D {
+      {a, b, transposeA, transposeB, bias, activation, preluActivationWeights}:
+          FusedBatchMatMulConfig): Tensor3D {
     const outerShapeA = transposeA ? a.shape[2] : a.shape[1];
     const outerShapeB = transposeB ? b.shape[1] : b.shape[2];
     const [batch, , ] = a.shape;
 
     const dtype = upcastType(a.dtype, b.dtype);
 
     const hasBias = bias != null;
+    const hasPreluActivationWeights = preluActivationWeights != null;
     const fusedActivation =
         activation ? mapActivationToShaderProgram(activation, true) : null;
     const program = new MatMulPackedProgram(
         a.shape, [batch, outerShapeA, outerShapeB], transposeA, transposeB,
-        hasBias, fusedActivation);
+        hasBias, fusedActivation, hasPreluActivationWeights);
     const output =
         this.makePackedTensor(program.outputShape, dtype) as Tensor3D;
     const inputs: TensorHandle[] = [a, b];
     if (bias) {
       inputs.push(bias);
     }
+    if (preluActivationWeights) {
+      inputs.push(preluActivationWeights);
+    }
     return this.compileAndRun<Tensor3D>(program, inputs, output);
   }
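End to end, the packed matmul path now accepts the slope tensor as a third optional input. A hedged usage sketch — it assumes the public tf.fused.matMul op forwards this same config object to the backend, which is outside this diff:

    import * as tf from '@tensorflow/tfjs-core';

    const a = tf.tensor3d([[[1, -2], [3, -4]]]);  // shape [1, 2, 2]
    const b = tf.tensor3d([[[5, 6], [7, 8]]]);    // shape [1, 2, 2]
    const alpha = tf.scalar(0.1);                 // PReLU slope, broadcast over the output

    // Assumed op-level mirror of the backend's FusedBatchMatMulConfig.
    const y = tf.fused.matMul({
      a,
      b,
      transposeA: false,
      transposeB: false,
      activation: 'prelu',
      preluActivationWeights: alpha
    });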

@@ -1819,7 +1828,7 @@ export class MathBackendWebGL implements KernelBackend {
 
   private conv2dByMatMul(
       x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
-      activation?: Activation): Tensor4D {
+      activation?: Activation, preluActivationWeights?: Tensor): Tensor4D {
     // Reshapes conv2D input to 2D tensors, uses matMul and then reshape the
     // result from 2D to 4D.
     const xShape = x.shape;
@@ -1850,9 +1859,15 @@
         Tensor3D;
 
     return this.reshape<Rank.R4>(
-        this.fusedBatchMatMul(
-            xReshaped, filterReshaped, transposeA, transposeB, bias,
-            activation),
+        this.fusedBatchMatMul({
+          a: xReshaped,
+          b: filterReshaped,
+          transposeA,
+          transposeB,
+          bias,
+          activation,
+          preluActivationWeights
+        }),
         convInfo.outShape);
   }
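The comment at the top of conv2dByMatMul compresses a fair amount of shape bookkeeping. Spelled out for NHWC (illustrative values, not code from the diff):

    // A 1x1, stride-1 convolution is a matrix multiply over the channel dim.
    function oneByOneConvAsMatMulShapes(
        batch: number, height: number, width: number, inCh: number,
        outCh: number) {
      const a = [1, batch * height * width, inCh];          // xReshaped
      const b = [1, inCh, outCh];                           // filterReshaped
      const product = [1, batch * height * width, outCh];   // fusedBatchMatMul out
      const out = [batch, height, width, outCh];            // convInfo.outShape
      return {a, b, product, out};
    }
    // oneByOneConvAsMatMulShapes(2, 8, 8, 16, 32).product -> [1, 128, 32]

Since bias, activation, and now preluActivationWeights are handed to fusedBatchMatMul, the whole conv + bias + PReLU chain runs in a single packed matmul shader.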

@@ -1888,8 +1903,15 @@
         this.reshape(filter, [1, convInfo.inChannels, convInfo.outChannels]) as
         Tensor3D;
 
-    const pointwiseConv = this.fusedBatchMatMul(
-        xReshaped, filterReshaped, transposeA, transposeB, bias, activation);
+    const pointwiseConv = this.fusedBatchMatMul({
+      a: xReshaped,
+      b: filterReshaped,
+      transposeA,
+      transposeB,
+      bias,
+      activation,
+      preluActivationWeights
+    });
     const pointwiseConvTexData = this.texData.get(pointwiseConv.dataId);
     util.assert(
         pointwiseConvTexData.isPacked,
@@ -1906,7 +1928,7 @@
 
   private conv2dWithIm2Row(
       x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
-      activation?: Activation): Tensor4D {
+      activation?: Activation, preluActivationWeights?: Tensor): Tensor4D {
     // Rearranges conv2d input so each block to be convolved over forms the
     // column of a new matrix with shape [filterWidth * filterHeight *
     // inChannels, outHeight * outWidth]. The filter is also rearranged so each
@@ -1938,42 +1960,53 @@
     ]) as Tensor3D;
 
     const hasBias = bias != null;
+    const hasPreluActivationWeights = preluActivationWeights != null;
     const fusedActivation =
         activation ? mapActivationToShaderProgram(activation, true) : null;
     const matmulProgram = new MatMulPackedProgram(
         im2Col.shape, [1, numCols, convInfo.outChannels], transposeA,
-        transposeB, hasBias, fusedActivation);
+        transposeB, hasBias, fusedActivation, hasPreluActivationWeights);
     const inputs: TensorHandle[] = [im2Col, w2Row];
     if (bias) {
       inputs.push(bias);
     }
+    if (hasPreluActivationWeights) {
+      inputs.push(preluActivationWeights);
+    }
     const product = this.compileAndRun<Tensor4D>(matmulProgram, inputs);
 
     return product.reshape([1, outHeight, outWidth, convInfo.outChannels]);
   }
 
   fusedConv2d(
       x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
-      activation?: Activation): Tensor4D {
+      activation?: Activation, preluActivationWeights?: Tensor): Tensor4D {
     if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
         convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
         convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&
         (convInfo.padInfo.type === 'SAME' ||
         convInfo.padInfo.type === 'VALID')) {
-      return this.conv2dByMatMul(x, filter, convInfo, bias, activation);
+      return this.conv2dByMatMul(
+          x, filter, convInfo, bias, activation, preluActivationWeights);
     }
     if (ENV.getBool('WEBGL_CONV_IM2COL') && x.shape[0] === 1) {
-      return this.conv2dWithIm2Row(x, filter, convInfo, bias, activation);
+      return this.conv2dWithIm2Row(
+          x, filter, convInfo, bias, activation, preluActivationWeights);
     }
 
     const hasBias = bias != null;
+    const hasPreluActivationWeights = preluActivationWeights != null;
     const fusedActivation =
         activation ? mapActivationToShaderProgram(activation, false) : null;
-    const program = new Conv2DProgram(convInfo, hasBias, fusedActivation);
+    const program = new Conv2DProgram(
+        convInfo, hasBias, fusedActivation, hasPreluActivationWeights);
     const inputs: TensorHandle[] = [x, filter];
     if (bias) {
       inputs.push(bias);
     }
+    if (preluActivationWeights) {
+      inputs.push(preluActivationWeights);
+    }
    return this.compileAndRun(program, inputs);
   }
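The three code paths in fusedConv2d differ only in how the matrix multiply is staged, and the fast-path test is worth restating on its own (illustrative; the real check reads convInfo directly):

    // A conv2d degenerates to a pure matMul over channels only when the
    // filter covers exactly one input pixel per output pixel.
    function canUseMatMulPath(c: {
      filterHeight: number, filterWidth: number, dilationHeight: number,
      dilationWidth: number, strideHeight: number, strideWidth: number,
      padType: 'SAME'|'VALID'|'NUMBER'
    }): boolean {
      return c.filterHeight === 1 && c.filterWidth === 1 &&
          c.dilationHeight === 1 && c.dilationWidth === 1 &&
          c.strideHeight === 1 && c.strideWidth === 1 &&
          (c.padType === 'SAME' || c.padType === 'VALID');
    }

Failing that, the im2col path is used for single-image batches when WEBGL_CONV_IM2COL is set, and the direct Conv2DProgram otherwise; all three paths now thread preluActivationWeights through.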
src/backends/webgl/conv_gpu.ts (17 additions & 5 deletions)

@@ -24,7 +24,8 @@ export class Conv2DProgram implements GPGPUProgram {
   userCode: string;
 
   constructor(
-      convInfo: Conv2DInfo, addBias = false, activation: string = null) {
+      convInfo: Conv2DInfo, addBias = false, activation: string = null,
+      hasPreluActivationWeights = false) {
     this.outputShape = convInfo.outShape;
     const padTop = convInfo.padInfo.top;
     const padLeft = convInfo.padInfo.left;
@@ -40,11 +41,18 @@
 
     let activationSnippet = '', applyActivationSnippet = '';
     if (activation) {
-      activationSnippet = `
-        float activation(float x) {
+      if (hasPreluActivationWeights) {
+        activationSnippet = `float activation(float a) {
+          float b = getPreluActivationWeightsAtOutCoords();
           ${activation}
-        }
-      `;
+        }`;
+      } else {
+        activationSnippet = `
+          float activation(float x) {
+            ${activation}
+          }
+        `;
+      }
 
       applyActivationSnippet = `result = activation(result);`;
     }
@@ -54,6 +62,10 @@
       this.variableNames.push('bias');
     }
 
+    if (hasPreluActivationWeights) {
+      this.variableNames.push('preluActivationWeights');
+    }
+
     this.userCode = `
       ${activationSnippet}
 
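With the pieces above in place, the constructor assembles a GLSL activation function and splices it into the shader. For a prelu activation (using the scalar snippet assumed earlier) with hasPreluActivationWeights = true, the assembled strings would come out roughly as:

    // Sketch of the assembled snippets for the prelu case; depends on the
    // assumed binaryop_gpu.PRELU body shown earlier, which is not in this diff.
    const activationSnippet = `float activation(float a) {
      float b = getPreluActivationWeightsAtOutCoords();
      return (a < 0.) ? b * a : a;
    }`;
    const applyActivationSnippet = `result = activation(result);`;

Note the prelu branch names its argument a and reads the slope into b so the spliced binary-op body can refer to both, whereas the non-prelu path keeps the old single-argument x convention.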
src/backends/webgl/mulmat_packed_gpu.ts (16 additions & 5 deletions)

@@ -26,7 +26,7 @@ export class MatMulPackedProgram implements GPGPUProgram {
   constructor(
       aShape: [number, number, number], outputShape: [number, number, number],
       transposeA = false, transposeB = false, addBias = false,
-      activation: string = null) {
+      activation: string = null, hasPreluActivation = false) {
     this.outputShape = outputShape;
 
     const sharedDim = transposeA ? aShape[1] : aShape[2];
@@ -39,9 +39,16 @@
 
     let activationSnippet = '', applyActivationSnippet = '';
     if (activation) {
-      activationSnippet = `vec4 activation(vec4 x) {
-        ${activation}
-      }`;
+      if (hasPreluActivation) {
+        activationSnippet = `vec4 activation(vec4 a) {
+          vec4 b = getPreluActivationWeightsAtOutCoords();
+          ${activation}
+        }`;
+      } else {
+        activationSnippet = `vec4 activation(vec4 x) {
+          ${activation}
+        }`;
+      }
 
       applyActivationSnippet = `result = activation(result);`;
     }
@@ -51,6 +58,10 @@
       this.variableNames.push('bias');
     }
 
+    if (hasPreluActivation) {
+      this.variableNames.push('preluActivationWeights');
+    }
+
     this.userCode = `
       ${activationSnippet}
 
@@ -82,4 +93,4 @@
       }
     `;
   }
-}
\ No newline at end of file
+}
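The packed program mirrors conv_gpu.ts, but everything is vectorized: activation takes and returns a vec4, and the slope lookup fetches four lanes at once. Pushing 'preluActivationWeights' onto variableNames is what makes the shader compiler generate the getPreluActivationWeightsAtOutCoords() accessor the snippet calls. Assembled with the packed PRELU body assumed earlier, the snippet would read roughly:

    // Sketch only; the actual binaryop_packed_gpu.PRELU body is outside
    // this diff.
    const activationSnippet = `vec4 activation(vec4 a) {
      vec4 b = getPreluActivationWeightsAtOutCoords();
      vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));
      return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);
    }`;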