|
19 | 19 | import {BackendTimingInfo, DataType, fill, KernelBackend, ones, Rank, rsqrt, scalar, ShapeMap, Tensor, Tensor1D, tensor1d, Tensor2D, tensor2d, Tensor3D, Tensor4D} from '@tensorflow/tfjs-core';
|
20 | 20 | import {Conv2DInfo} from '@tensorflow/tfjs-core/dist/ops/conv_util';
|
21 | 21 | import {upcastType} from '@tensorflow/tfjs-core/dist/types';
|
| 22 | + |
22 | 23 | import {TensorMetadata, TFEOpAttr, TFJSBinding} from './tfjs_binding';
|
23 | 24 |
|
24 | 25 | type TensorInfo = {
|
@@ -190,6 +191,26 @@ export class NodeJSKernelBackend implements KernelBackend {
|
190 | 191 | return this.executeSingleOutput('MatMul', opAttrs, [a, b]) as Tensor2D;
|
191 | 192 | }
|
192 | 193 |
|
| 194 | + stridedSlice<T extends Tensor<Rank>>( |
| 195 | + x: T, begin: number[], end: number[], strides: number[], |
| 196 | + beginMask: number, endMask: number): T { |
| 197 | + const beginTensor = tensor1d(begin, 'int32'); |
| 198 | + const endTensor = tensor1d(end, 'int32'); |
| 199 | + const stridesTensor = tensor1d(strides, 'int32'); |
| 200 | + const opAttrs = [ |
| 201 | + this.createTypeOpAttr('T', x.dtype), |
| 202 | + this.createTypeOpAttr('Index', 'int32'), |
| 203 | + {name: 'begin_mask', type: this.binding.TF_ATTR_INT, value: beginMask}, |
| 204 | + {name: 'end_mask', type: this.binding.TF_ATTR_INT, value: endMask}, |
| 205 | + {name: 'ellipsis_mask', type: this.binding.TF_ATTR_INT, value: 0}, |
| 206 | + {name: 'new_axis_mask', type: this.binding.TF_ATTR_INT, value: 0}, |
| 207 | + {name: 'shrink_axis_mask', type: this.binding.TF_ATTR_INT, value: 0} |
| 208 | + ]; |
| 209 | + return this.executeSingleOutput( |
| 210 | + 'StridedSlice', opAttrs, |
| 211 | + [x, beginTensor, endTensor, stridesTensor]) as T; |
| 212 | + } |
| 213 | + |
193 | 214 | slice<T extends Tensor>(x: T, begin: number[], size: number[]): T {
|
194 | 215 | const opAttrs = [
|
195 | 216 | this.createTypeOpAttr('T', x.dtype),
|
@@ -901,6 +922,18 @@ export class NodeJSKernelBackend implements KernelBackend {
|
901 | 922 | ]) as Tensor2D;
|
902 | 923 | }
|
903 | 924 |
|
| 925 | + cumsum(x: Tensor<Rank>, axis: number, exclusive: boolean, reverse: boolean): |
| 926 | + Tensor<Rank> { |
| 927 | + const axisTensor = scalar(axis, 'int32'); |
| 928 | + const opAttrs = [ |
| 929 | + {name: 'exclusive', type: this.binding.TF_ATTR_BOOL, value: exclusive}, |
| 930 | + {name: 'reverse', type: this.binding.TF_ATTR_BOOL, value: reverse}, |
| 931 | + this.createTypeOpAttr('T', x.dtype), |
| 932 | + this.createTypeOpAttr('Tidx', 'int32') |
| 933 | + ]; |
| 934 | + return this.executeSingleOutput('Cumsum', opAttrs, [x, axisTensor]); |
| 935 | + } |
| 936 | + |
904 | 937 | fromPixels(
|
905 | 938 | pixels: ImageData|HTMLImageElement|HTMLCanvasElement|HTMLVideoElement,
|
906 | 939 | numChannels: number): Tensor3D {
|
|
0 commit comments