Skip to content

Commit fb819ed

Browse files
authored
Update to tfjs-core ~0.11.1 (tensorflow#93)
1 parent 604437c commit fb819ed

File tree

3 files changed

+38
-5
lines changed

3 files changed

+38
-5
lines changed

package.json

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
"publish-local": "yarn prep && yarn build && yalc push"
2121
},
2222
"devDependencies": {
23-
"@tensorflow/tfjs-core": "~0.10.1",
23+
"@tensorflow/tfjs-core": "~0.11.1",
2424
"@types/bindings": "~1.3.0",
2525
"@types/jasmine": "~2.8.6",
2626
"@types/node": "~9.6.2",
@@ -36,6 +36,6 @@
3636
"bindings": "~1.3.0"
3737
},
3838
"peerDependencies": {
39-
"@tensorflow/tfjs-core": "~0.10.1"
39+
"@tensorflow/tfjs-core": "~0.11.0"
4040
}
4141
}

src/nodejs_kernel_backend.ts

+33
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import {BackendTimingInfo, DataType, fill, KernelBackend, ones, Rank, rsqrt, scalar, ShapeMap, Tensor, Tensor1D, tensor1d, Tensor2D, tensor2d, Tensor3D, Tensor4D} from '@tensorflow/tfjs-core';
2020
import {Conv2DInfo} from '@tensorflow/tfjs-core/dist/ops/conv_util';
2121
import {upcastType} from '@tensorflow/tfjs-core/dist/types';
22+
2223
import {TensorMetadata, TFEOpAttr, TFJSBinding} from './tfjs_binding';
2324

2425
type TensorInfo = {
@@ -190,6 +191,26 @@ export class NodeJSKernelBackend implements KernelBackend {
190191
return this.executeSingleOutput('MatMul', opAttrs, [a, b]) as Tensor2D;
191192
}
192193

194+
stridedSlice<T extends Tensor<Rank>>(
195+
x: T, begin: number[], end: number[], strides: number[],
196+
beginMask: number, endMask: number): T {
197+
const beginTensor = tensor1d(begin, 'int32');
198+
const endTensor = tensor1d(end, 'int32');
199+
const stridesTensor = tensor1d(strides, 'int32');
200+
const opAttrs = [
201+
this.createTypeOpAttr('T', x.dtype),
202+
this.createTypeOpAttr('Index', 'int32'),
203+
{name: 'begin_mask', type: this.binding.TF_ATTR_INT, value: beginMask},
204+
{name: 'end_mask', type: this.binding.TF_ATTR_INT, value: endMask},
205+
{name: 'ellipsis_mask', type: this.binding.TF_ATTR_INT, value: 0},
206+
{name: 'new_axis_mask', type: this.binding.TF_ATTR_INT, value: 0},
207+
{name: 'shrink_axis_mask', type: this.binding.TF_ATTR_INT, value: 0}
208+
];
209+
return this.executeSingleOutput(
210+
'StridedSlice', opAttrs,
211+
[x, beginTensor, endTensor, stridesTensor]) as T;
212+
}
213+
193214
slice<T extends Tensor>(x: T, begin: number[], size: number[]): T {
194215
const opAttrs = [
195216
this.createTypeOpAttr('T', x.dtype),
@@ -901,6 +922,18 @@ export class NodeJSKernelBackend implements KernelBackend {
901922
]) as Tensor2D;
902923
}
903924

925+
cumsum(x: Tensor<Rank>, axis: number, exclusive: boolean, reverse: boolean):
926+
Tensor<Rank> {
927+
const axisTensor = scalar(axis, 'int32');
928+
const opAttrs = [
929+
{name: 'exclusive', type: this.binding.TF_ATTR_BOOL, value: exclusive},
930+
{name: 'reverse', type: this.binding.TF_ATTR_BOOL, value: reverse},
931+
this.createTypeOpAttr('T', x.dtype),
932+
this.createTypeOpAttr('Tidx', 'int32')
933+
];
934+
return this.executeSingleOutput('Cumsum', opAttrs, [x, axisTensor]);
935+
}
936+
904937
fromPixels(
905938
pixels: ImageData|HTMLImageElement|HTMLCanvasElement|HTMLVideoElement,
906939
numChannels: number): Tensor3D {

yarn.lock

+3-3
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
# yarn lockfile v1
33

44

5-
"@tensorflow/tfjs-core@~0.10.1":
6-
version "0.10.1"
7-
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-core/-/tfjs-core-0.10.1.tgz#1764c8c122d5b8a86a8d23c8ffd58bef8e40eda9"
5+
"@tensorflow/tfjs-core@~0.11.1":
6+
version "0.11.1"
7+
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-core/-/tfjs-core-0.11.1.tgz#d26808f912529668d0a41228da37566b6b2f4f08"
88
dependencies:
99
seedrandom "~2.4.3"
1010

0 commit comments

Comments
 (0)