|
18 | 18 | import * as tf from '../index';
|
19 | 19 | import {BROWSER_ENVS, CHROME_ENVS, describeWithFlags, NODE_ENVS} from '../jasmine_util';
|
20 | 20 | import {HTTPRequest, httpRouter, parseUrl} from './http';
|
| 21 | +import {CompositeArrayBuffer} from './composite_array_buffer'; |
21 | 22 |
|
22 | 23 | // Test data.
|
23 | 24 | const modelTopology1: {} = {
|
@@ -161,7 +162,8 @@ describeWithFlags('http-load fetch', NODE_ENVS, () => {
|
161 | 162 | expect(modelArtifacts.generatedBy).toEqual('1.15');
|
162 | 163 | expect(modelArtifacts.convertedBy).toEqual('1.3.1');
|
163 | 164 | expect(modelArtifacts.userDefinedMetadata).toEqual({});
|
164 |
| - expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData); |
| 165 | + expect(new Float32Array(CompositeArrayBuffer.join( |
| 166 | + modelArtifacts.weightData))).toEqual(floatData); |
165 | 167 | });
|
166 | 168 |
|
167 | 169 | it('throw exception if no fetch polyfill', () => {
|
@@ -507,7 +509,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => {
|
507 | 509 | expect(modelArtifacts.userDefinedMetadata).toEqual({});
|
508 | 510 | expect(modelArtifacts.modelInitializer).toEqual({});
|
509 | 511 |
|
510 |
| - expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData); |
| 512 | + expect(new Float32Array(CompositeArrayBuffer.join(modelArtifacts |
| 513 | + .weightData))).toEqual(floatData); |
511 | 514 | expect(Object.keys(requestInits).length).toEqual(2);
|
512 | 515 | // Assert that fetch is invoked with `window` as the context.
|
513 | 516 | expect(fetchSpy.calls.mostRecent().object).toEqual(window);
|
@@ -550,7 +553,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => {
|
550 | 553 | const modelArtifacts = await handler.load();
|
551 | 554 | expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
|
552 | 555 | expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights);
|
553 |
| - expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData); |
| 556 | + expect(new Float32Array(CompositeArrayBuffer.join(modelArtifacts |
| 557 | + .weightData))).toEqual(floatData); |
554 | 558 | expect(Object.keys(requestInits).length).toEqual(2);
|
555 | 559 | expect(Object.keys(requestInits).length).toEqual(2);
|
556 | 560 | expect(requestInits['./model.json'].headers['header_key_1'])
|
@@ -599,8 +603,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => {
|
599 | 603 | const modelArtifacts = await handler.load();
|
600 | 604 | expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
|
601 | 605 | expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights);
|
602 |
| - expect(new Float32Array(modelArtifacts.weightData)) |
603 |
| - .toEqual(new Float32Array([1, 3, 3, 7, 4])); |
| 606 | + expect(new Float32Array(CompositeArrayBuffer.join(modelArtifacts |
| 607 | + .weightData))).toEqual(new Float32Array([1, 3, 3, 7, 4])); |
604 | 608 | });
|
605 | 609 |
|
606 | 610 | it('2 groups, 2 weight, 2 paths', async () => {
|
@@ -644,8 +648,9 @@ describeWithFlags('http-load', BROWSER_ENVS, () => {
|
644 | 648 | expect(modelArtifacts.weightSpecs)
|
645 | 649 | .toEqual(
|
646 | 650 | weightsManifest[0].weights.concat(weightsManifest[1].weights));
|
647 |
| - expect(new Float32Array(modelArtifacts.weightData)) |
648 |
| - .toEqual(new Float32Array([1, 3, 3, 7, 4])); |
| 651 | + expect(new Float32Array(CompositeArrayBuffer.join( |
| 652 | + modelArtifacts.weightData))) |
| 653 | + .toEqual(new Float32Array([1, 3, 3, 7, 4])); |
649 | 654 | });
|
650 | 655 |
|
651 | 656 | it('2 groups, 2 weight, 2 paths, Int32 and Uint8 Data', async () => {
|
@@ -689,10 +694,10 @@ describeWithFlags('http-load', BROWSER_ENVS, () => {
|
689 | 694 | expect(modelArtifacts.weightSpecs)
|
690 | 695 | .toEqual(
|
691 | 696 | weightsManifest[0].weights.concat(weightsManifest[1].weights));
|
692 |
| - expect(new Int32Array(modelArtifacts.weightData.slice(0, 12))) |
693 |
| - .toEqual(new Int32Array([1, 3, 3])); |
694 |
| - expect(new Uint8Array(modelArtifacts.weightData.slice(12, 14))) |
695 |
| - .toEqual(new Uint8Array([7, 4])); |
| 697 | + expect(new Int32Array(CompositeArrayBuffer.join(modelArtifacts.weightData) |
| 698 | + .slice(0, 12))).toEqual(new Int32Array([1, 3, 3])); |
| 699 | + expect(new Uint8Array(CompositeArrayBuffer.join(modelArtifacts.weightData) |
| 700 | + .slice(12, 14))).toEqual(new Uint8Array([7, 4])); |
696 | 701 | });
|
697 | 702 |
|
698 | 703 | it('topology only', async () => {
|
@@ -752,10 +757,11 @@ describeWithFlags('http-load', BROWSER_ENVS, () => {
|
752 | 757 | expect(modelArtifacts.weightSpecs)
|
753 | 758 | .toEqual(
|
754 | 759 | weightsManifest[0].weights.concat(weightsManifest[1].weights));
|
755 |
| - expect(new Int32Array(modelArtifacts.weightData.slice(0, 12))) |
756 |
| - .toEqual(new Int32Array([1, 3, 3])); |
757 |
| - expect(new Float32Array(modelArtifacts.weightData.slice(12, 20))) |
758 |
| - .toEqual(new Float32Array([-7, -4])); |
| 760 | + expect(new Int32Array(CompositeArrayBuffer.join(modelArtifacts.weightData) |
| 761 | + .slice(0, 12))).toEqual(new Int32Array([1, 3, 3])); |
| 762 | + expect(new Float32Array(CompositeArrayBuffer |
| 763 | + .join(modelArtifacts.weightData) |
| 764 | + .slice(12, 20))).toEqual(new Float32Array([-7, -4])); |
759 | 765 | });
|
760 | 766 |
|
761 | 767 | it('Missing modelTopology and weightsManifest leads to error', async () => {
|
@@ -840,7 +846,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => {
|
840 | 846 | const modelArtifacts = await handler.load();
|
841 | 847 | expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
|
842 | 848 | expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights);
|
843 |
| - expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData); |
| 849 | + expect(new Float32Array(CompositeArrayBuffer.join( |
| 850 | + modelArtifacts.weightData))).toEqual(floatData); |
844 | 851 | expect(Object.keys(requestInits).length).toEqual(2);
|
845 | 852 | expect(Object.keys(requestInits).length).toEqual(2);
|
846 | 853 | expect(requestInits['./model.json'].headers['header_key_1'])
|
@@ -902,7 +909,8 @@ describeWithFlags('http-load', BROWSER_ENVS, () => {
|
902 | 909 | expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
|
903 | 910 | expect(modelArtifacts.trainingConfig).toEqual(trainingConfig1);
|
904 | 911 | expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights);
|
905 |
| - expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData); |
| 912 | + expect(new Float32Array(CompositeArrayBuffer |
| 913 | + .join(modelArtifacts.weightData))).toEqual(floatData); |
906 | 914 |
|
907 | 915 | expect(fetchInputs).toEqual(['./model.json', './weightfile0']);
|
908 | 916 | expect(fetchInits.length).toEqual(2);
|
|
0 commit comments