diff --git a/src/nodejs_kernel_backend.ts b/src/nodejs_kernel_backend.ts index ba090f2a..fc20295c 100644 --- a/src/nodejs_kernel_backend.ts +++ b/src/nodejs_kernel_backend.ts @@ -120,14 +120,29 @@ export class NodeJSKernelBackend implements KernelBackend { return this.executeSingleOutput(name, opAttrs, [input]); } - private executeSingleOutput( - name: string, opAttrs: TFEOpAttr[], inputs: Tensor[]): Tensor { + /** + * Executes a TensorFlow Eager Op that provides one output Tensor. + * @param name The name of the Op to execute. + * @param opAttrs The list of Op attributes required to execute. + * @param inputs The list of input Tensors for the Op. + * @return A resulting Tensor from Op execution. + */ + executeSingleOutput(name: string, opAttrs: TFEOpAttr[], inputs: Tensor[]): + Tensor { const outputMetadata = this.binding.executeOp( name, opAttrs, this.getInputTensorIds(inputs), 1); return this.createOutputTensor(outputMetadata[0]); } - private executeMultipleOutputs( + /** + * Executes a TensorFlow Eager Op that provides multiple output Tensors. + * @param name The name of the Op to execute. + * @param opAttrs The list of Op attributes required to execute. + * @param inputs The list of input Tensors for the Op. + * @param numOutputs The number of output Tensors for Op execution. + * @return A resulting Tensor array from Op execution. + */ + executeMultipleOutputs( name: string, opAttrs: TFEOpAttr[], inputs: Tensor[], numOutputs: number): Tensor[] { const outputMetadata = this.binding.executeOp( diff --git a/src/ops/op_utils.ts b/src/ops/op_utils.ts new file mode 100644 index 00000000..0e5b53f5 --- /dev/null +++ b/src/ops/op_utils.ts @@ -0,0 +1,24 @@ +/** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tfc from '@tensorflow/tfjs-core'; +import {NodeJSKernelBackend} from '../nodejs_kernel_backend'; + +/** Returns an instance of the Node JS backend. */ +export function nodeJSBackend(): NodeJSKernelBackend { + return (tfc.ENV.findBackend('tensorflow') as NodeJSKernelBackend); +} diff --git a/src/ops/op_utils_test.ts b/src/ops/op_utils_test.ts new file mode 100644 index 00000000..82650fca --- /dev/null +++ b/src/ops/op_utils_test.ts @@ -0,0 +1,25 @@ +/** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {nodeJSBackend} from './op_utils'; + +describe('Exposes Backend for internal Op execution.', () => { + it('Provides the NodeJS backend over a function', () => { + const backend = nodeJSBackend(); + expect(backend).toBeDefined(); + }); +}); diff --git a/src/run_tests.ts b/src/run_tests.ts index 2591cca3..53e59958 100644 --- a/src/run_tests.ts +++ b/src/run_tests.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import './index'; +import '.'; import * as jasmine_util from '@tensorflow/tfjs-core/dist/jasmine_util'; Error.stackTraceLimit = Infinity; @@ -28,7 +28,9 @@ import bindings = require('bindings'); import {TFJSBinding} from './tfjs_binding'; import {NodeJSKernelBackend} from './nodejs_kernel_backend'; -process.on('unhandledRejection', e => { throw e; }); +process.on('unhandledRejection', e => { + throw e; +}); jasmine_util.setTestEnvs([{ name: 'test-tensorflow',