Skip to content

Commit

Permalink
Merge pull request #8 from JingyuanZhang/beta
Browse files Browse the repository at this point in the history
feat(core): support model frame & get fetch info in another way
  • Loading branch information
zhongkai authored Dec 10, 2020
2 parents 7e8c13d + 931d6ec commit dc4d9c4
Show file tree
Hide file tree
Showing 12 changed files with 161 additions and 34 deletions.
8 changes: 4 additions & 4 deletions packages/paddlejs-backend-webgpu/src/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ WebGPUBackend.prototype.runProgram = function (type, opData, isRendered) {
this.submitEncodedCommands();
}

WebGPUBackend.prototype.read = async function (fetchOp) {
const fetchId = fetchOp.opData.outputTensors[0].tensorId;
const fetchShape = fetchOp.opData.outputTensors[0].shape;
const fetchByteLength = fetchShape.reduce((acc, cur) => acc * cur, fetchShape[0]) * 4;
WebGPUBackend.prototype.read = async function (fetchInfo) {
const fetchId = fetchInfo.name;
const fetchShape = fetchInfo.shape;
const fetchByteLength = fetchShape.reduce((acc, cur) => acc * cur, 1) * 4;
this.createReadBuffer({
size: fetchByteLength
});
Expand Down
99 changes: 99 additions & 0 deletions packages/paddlejs-backend-webgpu/test/model/mock/model.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
{
"ops": [
{
"attrs": {
"__@kernel_type_attr@__": "feed/def/1/4/2"
},
"inputs": {
"X": [
"feed"
]
},
"outputs": {
"Out": [
"image"
]
},
"type": "feed"
},
{
"attrs": {
"axis": 1,
"use_mkldnn": false,
"x_data_format": "",
"y_data_format": ""
},
"inputs": {
"X": [
"image"
],
"Y": [
"fc7_offset"
]
},
"outputs": {
"Out": [
"fc_0.tmp_1"
]
},
"type": "elementwise_add"
},
{
"attrs": {
"x_num_col_dims": 1,
"y_num_col_dims": 1
},
"inputs": {
"X": ["fc_0.tmp_1"],
"Y": ["fc7_weights"]
},
"outputs": {
"Out": ["fc_0.tmp_0"]
},
"type": "mul"
},
{
"attrs": {
"__@kernel_type_attr@__": "fetch/def/1/4/2",
"data_type": 1
},
"inputs": {
"X": [
"fc_0.tmp_0"
]
},
"outputs": {
"Out": [
"fetch"
]
},
"type": "fetch"
}
],
"vars": [
{
"name": "fc_0.tmp_1",
"persistable": false,
"shape": [1, 1, 2, 4]
},
{
"data": [1, 1 , 1, 1, 1, 1, 1, 1],
"name": "fc7_offset",
"persistable": true,
"shape": [
1, 1, 2, 4
]
},
{
"data": [4, 1, 0, -1, 1, 3, 2, 0, 1, 1, 3, 4],
"name": "fc7_weights",
"persistable": true,
"shape": [1, 1, 4, 3]
},
{
"name": "fc_0.tmp_0",
"persistable": false,
"shape": [1, 1, 2, 3]
}
]
}
21 changes: 21 additions & 0 deletions packages/paddlejs-backend-webgpu/test/model/modelTest.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import { Runner } from 'paddlejs-core/src/index';
import registerWebGPUBackend from '../../src/index';

const modelDir = `/test/model/mock/`;
const modelPath = `${modelDir}model.json`;

async function run() {
const runner = new Runner({
modelPath,
feedShape: {
fw: 4,
fh: 2
}
});
await runner.init();
console.log(runner.weightMap);
console.log(await runner.preheat());
}

registerWebGPUBackend();
run();
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import { Runner } from '../../paddlejs-core/src/index';
import registerWebGPUBackend from '../src/index';
import { Runner } from 'paddlejs-core/src/index';
import registerWebGPUBackend from '../../src/index';

const opName = 'mul';
const modelDir = `/test/data/`;
const opName = 'conv2d';
const modelDir = `/test/op/data/`;
const modelPath = `${modelDir}${opName}.json`;

async function run() {
const runner = new Runner({
modelPath,
fetchShape: [1, 4, 2, 2]
modelPath
});
await runner.init();
const executeOP = runner.weightMap[0];
Expand Down
8 changes: 8 additions & 0 deletions packages/paddlejs-core/src/commons/interface.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,11 @@ export interface OpData {
export interface AttrsData {
[key: string]: any
}


export interface InputFeed {
data: Float32Array | number[];
shape: number[];
name: string;
canvas?: number[];
}
7 changes: 4 additions & 3 deletions packages/paddlejs-core/src/graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,11 @@ export default class ModelGraph {

/**
* Get weightMap end Node FETCH
* @returns {OpExecutor}
* @returns {ModelVar}
*/
getFetchExecutor() : OpExecutor | undefined {
return this.weightMap.find((item: OpExecutor) => item.type === 'fetch');
getFetchExecutorInfo() : ModelVar {
const fetchOp: OpExecutor = this.weightMap.find((item: OpExecutor) => item.type === 'fetch') as OpExecutor;
return this.vars.find(item => item.name === fetchOp.inputs.X[0]) as ModelVar;
}

/**
Expand Down
10 changes: 6 additions & 4 deletions packages/paddlejs-core/src/opFactory/opDataBuilder.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { ModelVar, OpExecutor, OpInputs, OpOutputs, AttrsData } from '../commons/interface';
import { ModelVar, OpExecutor, OpInputs, OpOutputs, AttrsData, InputFeed } from '../commons/interface';
import { GLOBALS } from '../globals';
import Tensor from './tensor';
import opBehaviors from './opBehaviors';
Expand All @@ -19,8 +19,9 @@ export default class OpData {
iLayer: number = 0;
program: any[] = [];
renderData: any[] = [];
inputFeed: InputFeed | undefined = {} as InputFeed;

constructor(op: OpExecutor, iLayer: number, vars: ModelVar[]) {
constructor(op: OpExecutor, iLayer: number, vars: ModelVar[], feed?: InputFeed) {
const {
type,
inputs,
Expand All @@ -35,6 +36,7 @@ export default class OpData {
this.checkMergeOp();
this.vars = vars;
this.iLayer = iLayer;
this.inputFeed = feed;

const isPass = this.checkIsPass();
if (isPass) {
Expand Down Expand Up @@ -71,8 +73,8 @@ export default class OpData {
});
});
Object.keys(this.input).forEach(key => {
if ((key === 'Input') || (key === 'X')) {
this.input[key] = this.getTensorAttr(this.input[key][0]);
if (this.input[key][0] === 'image') {
this.input[key] = [this.inputFeed];
}
else {
this.input[key] = this.getTensorAttr(this.input[key][0]);
Expand Down
2 changes: 1 addition & 1 deletion packages/paddlejs-core/src/opFactory/tensor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ export default class Tensor {
shape_texture_packed: number[] = [];
shape_packed: number[] = [];
exceedMax: boolean = false;
data: Float32Array | number[] | null = [];
data: Float32Array | number[] | null = null;

constructor(opts: any = {}) {
this.opts = opts;
Expand Down
29 changes: 13 additions & 16 deletions packages/paddlejs-core/src/runner.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import Loader from './loader';
import Graph from './graph';
import { Model } from './commons/interface';
import { Model, InputFeed } from './commons/interface';
import OpData from './opFactory/opDataBuilder';
import { GLOBALS } from './globals';
import type OpExecutor from './opFactory/opExecutor';
Expand All @@ -22,12 +22,6 @@ interface ModelConfig {
needPreheat?: boolean; // 是否需要预热
}

interface InputFeed {
data: Float32Array | number[];
shape: number[];
name: string;
canvas?: number[];
}

export default class Runner {
// instance field
Expand All @@ -51,6 +45,7 @@ export default class Runner {
isExecuted: boolean = false;
test: boolean = false;
graphGenerator: Graph = {} as Graph;
feedData: InputFeed = {} as InputFeed;

constructor(options: ModelConfig | null) {
const opts = {
Expand All @@ -75,7 +70,6 @@ export default class Runner {
await GLOBALS.backendInstance.init();
await this.load();
this.genGraph();
this.genOpData();
}

async load() {
Expand All @@ -100,7 +94,7 @@ export default class Runner {
const type = op.type;
if (type !== 'feed' && type !== 'fetch') {
iLayer++;
const opData = new OpData(op, iLayer, vars);
const opData = new OpData(op, iLayer, vars, this.feedData);
op.opData = opData;
}
});
Expand All @@ -109,21 +103,21 @@ export default class Runner {
async preheat() {
await this.checkModelLoaded();
const { fh, fw } = this.modelConfig.feedShape;
const feed: InputFeed = {
data: new Float32Array(3 * fh * fw).fill(5.0),
const preheatFeed: InputFeed = {
data: new Float32Array(3 * fh * fw).fill(1.0),
name: 'image',
shape: [1, 3, fh, fw]
};
this.execute(feed);
const result = await this.execute(preheatFeed);
this.isExecuted = true;
return result;
}

private async checkModelLoaded() {
if (this.weightMap.length === 0) {
console.info('It\'s better to preheat the model before running.');
await this.load();
this.genGraph();
this.genOpData();
}
}

Expand All @@ -133,7 +127,10 @@ export default class Runner {
}

async execute(feed) {
console.log(feed);
this.feedData = feed;
if (!this.isExecuted) {
this.genOpData();
}
const FeedOp = this.graphGenerator.getFeedExecutor() as OpExecutor;
this.executeOp(FeedOp);
return await this.read();
Expand All @@ -152,7 +149,7 @@ export default class Runner {
}

async read() {
const fetchOp = this.graphGenerator.getFetchExecutor() as OpExecutor;
return await GLOBALS.backendInstance.read(fetchOp);
const fetchInfo = this.graphGenerator.getFetchExecutorInfo();
return await GLOBALS.backendInstance.read(fetchInfo);
}
};

0 comments on commit dc4d9c4

Please sign in to comment.