Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions tfjs-backend-webgpu/src/backend_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import './flags_webgpu';

import {backend_util, buffer, DataStorage, DataType, engine, env, GPUData, KernelBackend, Rank, RecursiveArray, ShapeMap, TensorBuffer, TensorInfo, TimingInfo, TypedArray, util} from '@tensorflow/tfjs-core';
import {backend_util, buffer, DataStorage, DataType, engine, env, GPUData, GPUReadData, KernelBackend, Rank, RecursiveArray, ShapeMap, TensorBuffer, TensorInfo, TimingInfo, TypedArray, util} from '@tensorflow/tfjs-core';

import {BufferManager} from './buffer_manager';
import {TextureManager} from './texture_manager';
Expand Down Expand Up @@ -50,6 +50,7 @@ type TensorData = {
shape: number[],
refCount: number,
resourceInfo?: BufferInfo|TextureInfo,
external?: boolean,
// For complex numbers, the real and imaginary parts are stored as their own
// individual tensors, with a parent joining the two with the
// complexTensorInfos field.
Expand Down Expand Up @@ -232,7 +233,7 @@ export class WebGPUBackend extends KernelBackend {

releaseResource(dataId: DataId) {
const tensorData = this.tensorMap.get(dataId);
if (!tensorData || !tensorData.resourceInfo) {
if (!tensorData || !tensorData.resourceInfo || tensorData.external) {
return;
}
if ('texture' in tensorData.resourceInfo) {
Expand Down Expand Up @@ -287,6 +288,30 @@ export class WebGPUBackend extends KernelBackend {
return dataId;
}

/**
 * Registers a caller-supplied GPU buffer as the storage of a new data id.
 *
 * The buffer in `values.buffer` is referenced directly (not copied) and the
 * entry is flagged `external`.
 *
 * @param values The GPU buffer together with the shape and dtype to
 *     interpret it with.
 * @returns The newly created data id.
 * @throws Error if `values.dtype` is 'complex64'.
 */
writeFromGPUBuffer(values: GPUReadData): DataId {
  const isComplex = values.dtype === 'complex64';
  if (isComplex) {
    throw new Error(`Cannot write to a complex64 dtype. `);
  }

  const dataId = {id: this.nextDataId()};
  const initialEntry = {
    dtype: values.dtype,
    shape: values.shape,
    values: null,
    refCount: 1,
    external: true
  };
  this.tensorMap.set(dataId, initialEntry);

  // Read the entry back from the map and attach the buffer description to it.
  const entry = this.tensorMap.get(dataId);
  const bytesPerElement = webgpu_util.GPUBytesPerElement(entry.dtype);
  const elementCount = util.sizeFromShape(entry.shape);
  const size = bytesPerElement * elementCount;
  const usage = this.defaultGpuBufferUsage();
  entry.resourceInfo = {size, usage, buffer: values.buffer};

  return dataId;
}

move(
dataId: DataId, values: backend_util.BackendValues, shape: number[],
dtype: DataType, refCount: number): void {
Expand Down Expand Up @@ -838,6 +863,10 @@ export class WebGPUBackend extends KernelBackend {
return this.tensorMap.numDataIds() - this.tensorDataPendingDisposal.length;
}

/**
 * Returns the GPUDevice stored on this backend instance (`this.device`).
 */
getGPUDevice(): GPUDevice {
return this.device;
}

dispose() {
if (this.disposed) {
return;
Expand Down
140 changes: 140 additions & 0 deletions tfjs-backend-webgpu/src/tensor_test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
/**
* @license
* Copyright 2022 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import * as tf from '@tensorflow/tfjs-core';
import {describeWebGPU} from './test_util';

/**
 * Creates a GPUBuffer with COPY_DST | STORAGE usage that holds `data`.
 *
 * The data is first written into a mapped staging buffer and then copied to
 * the returned buffer with a buffer-to-buffer copy submitted on the device
 * queue. The staging buffer is destroyed before returning.
 *
 * @param device Device used to create the buffers and submit the copy.
 * @param data Element values to upload.
 * @param dtype Element type used to encode `data`; 'float32' and 'int32' are
 *     supported.
 * @returns The GPUBuffer that receives the copy.
 * @throws Error if `dtype` is neither 'float32' nor 'int32'.
 */
async function createReadonlyGPUBufferFromData(
    device: GPUDevice, data: number[], dtype: string) {
  // Validate before allocating so an unsupported dtype cannot leak a buffer.
  if (dtype !== 'float32' && dtype !== 'int32') {
    throw new Error(
        `createReadonlyGPUBufferFromData() does not support dtype ${dtype}`);
  }

  // Both supported dtypes use four bytes per element.
  const bytesPerElement = 4;
  const sizeInBytes = data.length * bytesPerElement;

  const gpuWriteBuffer = device.createBuffer({
    mappedAtCreation: true,
    size: sizeInBytes,
    usage: GPUBufferUsage.MAP_WRITE | GPUBufferUsage.COPY_SRC
  });
  const arrayBuffer = gpuWriteBuffer.getMappedRange();

  // Encode with the view matching the requested dtype; writing int32 data
  // through a Float32Array would store the wrong bit pattern.
  if (dtype === 'int32') {
    new Int32Array(arrayBuffer).set(data);
  } else {
    new Float32Array(arrayBuffer).set(data);
  }

  gpuWriteBuffer.unmap();

  const aBuffer = device.createBuffer({
    mappedAtCreation: false,
    size: sizeInBytes,
    usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE
  });

  const copyEncoder = device.createCommandEncoder();
  copyEncoder.copyBufferToBuffer(gpuWriteBuffer, 0, aBuffer, 0, sizeInBytes);

  const copyCommands = copyEncoder.finish();
  device.queue.submit([copyCommands]);
  gpuWriteBuffer.destroy();
  return aBuffer;
}

// Tuple shapes of rank 1 through 6, one union member per rank.
type Shape = [number]|[number, number]|[number, number, number]|
[number, number, number, number]|[number, number, number, number, number]|
[number, number, number, number, number, number];

describeWebGPU('tensor', () => {
  /**
   * Adds `a` and `b`, compares the sum against `expected`, and then disposes
   * both operands and the result, even when the comparison throws, so that
   * no tensor created by a test case is left undisposed.
   */
  async function expectSumThenDispose(
      a: tf.Tensor, b: tf.Tensor, expected: number[]) {
    const result = tf.add(a, b);
    try {
      tf.test_util.expectArraysClose(await result.data(), expected);
    } finally {
      a.dispose();
      b.dispose();
      result.dispose();
    }
  }

  it('tensor from GPUBuffer 1d 2d 3d 4d 5d 6d', async () => {
    const device = tf.backend().getGPUDevice();
    const aData = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
    const b =
        new Float32Array([1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4]);
    const expected = [2, 4, 6, 8, 6, 8, 10, 12, 10, 12, 14, 16, 14, 16, 18, 20];
    const dtype = 'float32';
    const aBuffer = await createReadonlyGPUBufferFromData(device, aData, dtype);
    // The same GPU buffer is wrapped once per rank; destroy it only after
    // every rank has been checked, and also when a check fails.
    try {
      {
        const shape: Shape = [16];
        const gpuReadData = new tf.GPUReadData(aBuffer, shape, dtype);
        await expectSumThenDispose(
            tf.tensor1d(gpuReadData), tf.tensor1d(b, 'float32', shape),
            expected);
      }
      {
        const shape: Shape = [8, 2];
        const gpuReadData = new tf.GPUReadData(aBuffer, shape, dtype);
        await expectSumThenDispose(
            tf.tensor2d(gpuReadData), tf.tensor2d(b, shape), expected);
      }
      {
        const shape: Shape = [2, 4, 2];
        const gpuReadData = new tf.GPUReadData(aBuffer, shape, dtype);
        await expectSumThenDispose(
            tf.tensor3d(gpuReadData), tf.tensor3d(b, shape), expected);
      }
      {
        const shape: Shape = [2, 2, 2, 2];
        const gpuReadData = new tf.GPUReadData(aBuffer, shape, dtype);
        await expectSumThenDispose(
            tf.tensor4d(gpuReadData), tf.tensor4d(b, shape), expected);
      }
      {
        const shape: Shape = [1, 2, 2, 2, 2];
        const gpuReadData = new tf.GPUReadData(aBuffer, shape, dtype);
        await expectSumThenDispose(
            tf.tensor5d(gpuReadData), tf.tensor5d(b, shape), expected);
      }
      {
        const shape: Shape = [1, 1, 2, 2, 2, 2];
        const gpuReadData = new tf.GPUReadData(aBuffer, shape, dtype);
        await expectSumThenDispose(
            tf.tensor6d(gpuReadData), tf.tensor6d(b, shape), expected);
      }
    } finally {
      aBuffer.destroy();
    }
  });

  it('two tensors share the same GPUBuffer', async () => {
    const device = tf.backend().getGPUDevice();
    const aData = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
    const expected =
        [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32];
    const dtype = 'float32';
    const aBuffer = await createReadonlyGPUBufferFromData(device, aData, dtype);

    try {
      const shape: Shape = [16];
      const gpuReadData = new tf.GPUReadData(aBuffer, shape, dtype);
      const a = tf.tensor1d(gpuReadData);
      const b = tf.tensor1d(gpuReadData);
      // Both tensors wrap the same buffer; each one is disposed, not just `a`.
      await expectSumThenDispose(a, b, expected);
    } finally {
      aBuffer.destroy();
    }
  });
});
9 changes: 8 additions & 1 deletion tfjs-core/src/backends/backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* =============================================================================
*/

import {Backend, DataId, DataToGPUOptions, GPUData} from '../tensor';
import {Backend, DataId, DataToGPUOptions, GPUData, GPUReadData} from '../tensor';
import {BackendValues, DataType} from '../types';

export const EPSILON_FLOAT32 = 1e-7;
Expand Down Expand Up @@ -127,6 +127,9 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer {
write(values: BackendValues, shape: number[], dtype: DataType): DataId {
// Base-class placeholder: forwards this method's name to notYetImplemented.
return notYetImplemented('write');
}
/**
 * Creates a data id from GPU-resident data (buffer, shape and dtype carried
 * by `values`). Base-class placeholder: forwards this method's name to
 * notYetImplemented.
 *
 * @param values The GPU buffer with its shape and dtype.
 */
writeFromGPUBuffer(values: GPUReadData): DataId {
  return notYetImplemented('writeFromGPUBuffer');
}
move(
dataId: DataId, values: BackendValues, shape: number[], dtype: DataType,
refCount: number): void {
Expand All @@ -139,6 +142,10 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer {
floatPrecision(): 16|32 {
// Base-class placeholder: forwards this method's name to notYetImplemented.
return notYetImplemented('floatPrecision');
}
/**
 * Returns the GPUDevice (WebGPU backend only). Base-class placeholder:
 * forwards this method's name to notYetImplemented.
 */
getGPUDevice(): GPUDevice {
  // Report the real method name so the resulting message points at the
  // method that was actually called.
  return notYetImplemented('getGPUDevice');
}
/** Returns the smallest representable number. */
epsilon(): number {
return this.floatPrecision() === 32 ? EPSILON_FLOAT32 : EPSILON_FLOAT16;
Expand Down
2 changes: 1 addition & 1 deletion tfjs-core/src/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ export {Optimizer} from './optimizers/optimizer';
export {OptimizerConstructors} from './optimizers/optimizer_constructors';
export {RMSPropOptimizer} from './optimizers/rmsprop_optimizer';
export {SGDOptimizer} from './optimizers/sgd_optimizer';
export {DataToGPUOptions, DataToGPUWebGLOption, GPUData, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer, Variable} from './tensor';
export {DataToGPUOptions, DataToGPUWebGLOption, GPUData, GPUReadData, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer, Variable} from './tensor';
export {GradSaveFunc, NamedTensorMap, TensorContainer, TensorContainerArray, TensorContainerObject} from './tensor_types';
export {BackendValues, DataType, DataTypeMap, DataValues, NumericDataType, PixelData, Rank, RecursiveArray, ScalarLike, ShapeMap, sumOutType, TensorLike, TypedArray, upcastType} from './types';

Expand Down
25 changes: 22 additions & 3 deletions tfjs-core/src/engine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import {getGradient, getKernel, getKernelsForBackend, GradFunc, NamedAttrMap, Te
import * as log from './log';
import {KernelProfile, Profiler} from './profiler';
import {backpropagateGradients, getFilteredNodesXToY, TapeNode} from './tape';
import {DataId, DataToGPUOptions, GPUData, setTensorTracker, Tensor, TensorTracker, Variable} from './tensor';
import {DataId, DataToGPUOptions, GPUData, GPUReadData, setTensorTracker, Tensor, TensorTracker, Variable} from './tensor';
import {GradSaveFunc, NamedTensorMap, NamedVariableMap, TensorContainer} from './tensor_types';
import {getTensorsInContainer} from './tensor_util';
import {BackendValues, DataType, DataValues} from './types';
Expand Down Expand Up @@ -823,15 +823,34 @@ export class Engine implements TensorTracker, DataMover {
return t;
}

/**
 * Internal method used by public APIs for tensor creation. Makes a new
 * tensor whose shape and dtype come from `values` and whose data is the GPU
 * buffer carried by `values`. It always creates a new data id by calling
 * `backend.writeFromGPUBuffer()`.
 *
 * @param values The GPU buffer together with its shape and dtype.
 * @param backend Backend that receives the data; defaults to `this.backend`.
 * @returns The new tensor, already tracked by the engine.
 * @throws Error if `values` or `values.buffer` is null or undefined.
 */
makeTensorFromGPUBuffer(values: GPUReadData, backend?: KernelBackend):
    Tensor {
  // Guard `values` itself too, so a missing argument yields this error
  // rather than a property-access TypeError.
  if (values == null || values.buffer == null) {
    throw new Error(
        'Values passed to engine.makeTensorFromGPUBuffer() are null');
  }
  backend = backend || this.backend;

  const dataId = backend.writeFromGPUBuffer(values);
  const t =
      new Tensor(values.shape, values.dtype, dataId, this.nextTensorId());
  this.trackTensor(t, backend);
  return t;
}

/**
* Internal method used by backends. Makes a new tensor
* that is a wrapper around an existing data id. It doesn't create
* a new data id, only increments the ref count used in memory tracking.
* @deprecated
*/
makeTensorFromDataId(
dataId: DataId, shape: number[], dtype: DataType,
backend?: KernelBackend): Tensor {
dataId: DataId, shape: number[], dtype: DataType,
backend?: KernelBackend): Tensor {
dtype = dtype || 'float32';
const tensorInfo: TensorInfo = {dataId, shape, dtype};
return this.makeTensorFromTensorInfo(tensorInfo, backend);
Expand Down
26 changes: 17 additions & 9 deletions tfjs-core/src/ops/tensor1d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
* =============================================================================
*/

import {Tensor1D} from '../tensor';
import {ENGINE} from '../engine';
import {GPUReadData, Tensor1D} from '../tensor';
import {inferShape} from '../tensor_util_env';
import {TensorLike1D} from '../types';
import {DataType} from '../types';
import {assertNonNull} from '../util';

import {makeTensor} from './tensor_ops_util';

/**
Expand All @@ -33,17 +35,23 @@ import {makeTensor} from './tensor_ops_util';
* ```
*
* @param values The values of the tensor. Can be array of numbers,
* or a `TypedArray`.
* or a `TypedArray`, or a `GPUReadData`.
* @param dtype The data type.
*
* @doc {heading: 'Tensors', subheading: 'Creation'}
*/
export function tensor1d(
    values: TensorLike1D|GPUReadData, dtype?: DataType,
    shape?: number[]): Tensor1D {
  if (values instanceof GPUReadData) {
    // A GPU-backed tensor takes its shape from the GPUReadData, so the rank-1
    // contract of tensor1d() has to be checked here; inferShape() is not
    // called on this path.
    if (values.shape == null || values.shape.length !== 1) {
      throw new Error('tensor1d() requires GPUReadData with a rank-1 shape');
    }
    return ENGINE.makeTensorFromGPUBuffer(values) as Tensor1D;
  }

  assertNonNull(values);
  const inferredShape = inferShape(values, dtype);
  if (inferredShape.length !== 1) {
    throw new Error('tensor1d() requires values to be a flat/TypedArray');
  }
  // NOTE(review): the `shape` parameter is accepted but not forwarded; the
  // shape is always taken from `inferredShape`. The local is named distinctly
  // so it no longer shadows the parameter.
  const explicitShape: number[] = null;
  return makeTensor(values, explicitShape, inferredShape, dtype) as Tensor1D;
}
Loading