Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Add tf.io.http and deprecate tf.io.browserHTTPRequest #1684

Merged
merged 7 commits into from
Apr 22, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 23 additions & 14 deletions src/io/browser_http.ts → src/io/http.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import {loadWeightsAsArrayBuffer} from './weights_loader';

const OCTET_STREAM_MIME_TYPE = 'application/octet-stream';
const JSON_TYPE = 'application/json';
export class BrowserHTTPRequest implements IOHandler {
export class HTTPRequest implements IOHandler {
protected readonly path: string;
protected readonly requestInit: RequestInit;

Expand Down Expand Up @@ -62,14 +62,13 @@ export class BrowserHTTPRequest implements IOHandler {

assert(
path != null && path.length > 0,
() =>
'URL path for browserHTTPRequest must not be null, undefined or ' +
() => 'URL path for http must not be null, undefined or ' +
'empty.');

if (Array.isArray(path)) {
assert(
path.length === 2,
() => 'URL paths for browserHTTPRequest must have a length of 2, ' +
() => 'URL paths for http must have a length of 2, ' +
`(actual length is ${path.length}).`);
}
this.path = path;
Expand Down Expand Up @@ -135,7 +134,7 @@ export class BrowserHTTPRequest implements IOHandler {
/**
* Load model artifacts via HTTP request(s).
*
* See the documentation to `browserHTTPRequest` for details on the saved
* See the documentation to `tf.io.http` for details on the saved
* artifacts.
*
* @returns The loaded model artifacts (if loading succeeds).
Expand Down Expand Up @@ -236,14 +235,15 @@ export function parseUrl(url: string): [string, string] {
}

export function isHTTPScheme(url: string): boolean {
return url.match(BrowserHTTPRequest.URL_SCHEME_REGEX) != null;
return url.match(HTTPRequest.URL_SCHEME_REGEX) != null;
}

export const httpRequestRouter: IORouter =
export const httpRouter: IORouter =
(url: string, onProgress?: OnProgressCallback) => {
if (typeof fetch === 'undefined') {
// browserHTTPRequest uses `fetch`, if one wants to use it in node.js
// they have to setup a global fetch polyfill.
// `http` uses `fetch` or `node-fetch`, if one wants to use it in
// an environment that is not the browser or node they have to setup a
// global fetch polyfill.
return null;
} else {
let isHTTP = true;
Expand All @@ -253,13 +253,13 @@ export const httpRequestRouter: IORouter =
isHTTP = isHTTPScheme(url);
}
if (isHTTP) {
return browserHTTPRequest(url, {onProgress});
return http(url, {onProgress});
}
}
return null;
};
IORouterRegistry.registerSaveRouter(httpRequestRouter);
IORouterRegistry.registerLoadRouter(httpRequestRouter);
IORouterRegistry.registerSaveRouter(httpRouter);
IORouterRegistry.registerLoadRouter(httpRouter);

/**
* Creates an IOHandler subtype that sends model artifacts to HTTP server.
Expand All @@ -281,7 +281,7 @@ IORouterRegistry.registerLoadRouter(httpRequestRouter);
* model.add(
* tf.layers.dense({units: 1, inputShape: [100], activation: 'sigmoid'}));
*
* const saveResult = await model.save(tf.io.browserHTTPRequest(
* const saveResult = await model.save(tf.io.http(
* 'http://model-server:5000/upload', {method: 'PUT'}));
* console.log(saveResult);
* ```
Expand Down Expand Up @@ -325,7 +325,16 @@ IORouterRegistry.registerLoadRouter(httpRequestRouter);
* @returns An instance of `IOHandler`.
*/
/** @doc {heading: 'Models', subheading: 'Loading', namespace: 'io'} */
export function http(path: string, loadOptions?: LoadOptions): IOHandler {
return new HTTPRequest(path, loadOptions);
}

/**
* Deprecated. Use `tf.io.http`.
* @param path
* @param loadOptions
*/
export function browserHTTPRequest(
path: string, loadOptions?: LoadOptions): IOHandler {
return new BrowserHTTPRequest(path, loadOptions);
return http(path, loadOptions);
}
91 changes: 41 additions & 50 deletions src/io/browser_http_test.ts → src/io/http_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,9 @@
* =============================================================================
*/

/**
* Unit tests for browser_http.ts.
*/

import * as tf from '../index';
import {BROWSER_ENVS, CHROME_ENVS, describeWithFlags, NODE_ENVS} from '../jasmine_util';
import {BrowserHTTPRequest, httpRequestRouter, parseUrl} from './browser_http';
import {HTTPRequest, httpRouter, parseUrl} from './http';

// Test data.
const modelTopology1: {} = {
Expand Down Expand Up @@ -57,7 +53,7 @@ const modelTopology1: {} = {
'backend': 'tensorflow'
};

let windowFetchSpy: jasmine.Spy;
let fetchSpy: jasmine.Spy;

type TypedArrays = Float32Array|Int32Array|Uint8Array|Uint16Array;
const fakeResponse =
Expand Down Expand Up @@ -85,22 +81,20 @@ const setupFakeWeightFiles =
}
},
requestInits: {[key: string]: RequestInit}) => {
windowFetchSpy =
// tslint:disable-next-line:no-any
spyOn(tf.util, 'fetch')
.and.callFake((path: string, init: RequestInit) => {
if (fileBufferMap[path]) {
requestInits[path] = init;
return Promise.resolve(fakeResponse(
fileBufferMap[path].data, fileBufferMap[path].contentType,
path));
} else {
return Promise.reject('path not found');
}
});
fetchSpy = spyOn(tf.util, 'fetch')
.and.callFake((path: string, init: RequestInit) => {
if (fileBufferMap[path]) {
requestInits[path] = init;
return Promise.resolve(fakeResponse(
fileBufferMap[path].data,
fileBufferMap[path].contentType, path));
} else {
return Promise.reject('path not found');
}
});
};

describeWithFlags('browserHTTPRequest-load fetch', NODE_ENVS, () => {
describeWithFlags('http-load fetch', NODE_ENVS, () => {
let requestInits: {[key: string]: {headers: {[key: string]: string}}};
// tslint:disable-next-line:no-any
let originalFetch: any;
Expand Down Expand Up @@ -149,7 +143,7 @@ describeWithFlags('browserHTTPRequest-load fetch', NODE_ENVS, () => {
},
requestInits);

const handler = tf.io.browserHTTPRequest('./model.json');
const handler = tf.io.http('./model.json');
const modelArtifacts = await handler.load();
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights);
Expand All @@ -160,7 +154,7 @@ describeWithFlags('browserHTTPRequest-load fetch', NODE_ENVS, () => {
// tslint:disable-next-line:no-any
delete (global as any).fetch;
try {
tf.io.browserHTTPRequest('./model.json');
tf.io.http('./model.json');
} catch (err) {
expect(err.message).toMatch(/Unable to find fetch polyfill./);
}
Expand All @@ -169,7 +163,7 @@ describeWithFlags('browserHTTPRequest-load fetch', NODE_ENVS, () => {

// Turned off for other browsers due to:
// https://github.com/tensorflow/tfjs/issues/426
describeWithFlags('browserHTTPRequest-save', CHROME_ENVS, () => {
describeWithFlags('http-save', CHROME_ENVS, () => {
// Test data.
const weightSpecs1: tf.io.WeightsManifestEntry[] = [
{
Expand Down Expand Up @@ -301,7 +295,7 @@ describeWithFlags('browserHTTPRequest-save', CHROME_ENVS, () => {

it('Save topology and weights, PUT method, extra headers', (done) => {
const testStartDate = new Date();
const handler = tf.io.browserHTTPRequest('model-upload-test', {
const handler = tf.io.http('model-upload-test', {
requestInit: {
method: 'PUT',
headers:
Expand Down Expand Up @@ -373,7 +367,7 @@ describeWithFlags('browserHTTPRequest-save', CHROME_ENVS, () => {
handler.save(artifacts1)
.then(saveResult => {
done.fail(
'Calling browserHTTPRequest at invalid URL succeeded ' +
'Calling http at invalid URL succeeded ' +
'unexpectedly');
})
.catch(err => {
Expand All @@ -384,33 +378,30 @@ describeWithFlags('browserHTTPRequest-save', CHROME_ENVS, () => {
it('getLoadHandlers with one URL string', () => {
const handlers = tf.io.getLoadHandlers('http://foo/model.json');
expect(handlers.length).toEqual(1);
expect(handlers[0] instanceof BrowserHTTPRequest).toEqual(true);
expect(handlers[0] instanceof HTTPRequest).toEqual(true);
});

it('Existing body leads to Error', () => {
expect(() => tf.io.browserHTTPRequest('model-upload-test', {
expect(() => tf.io.http('model-upload-test', {
requestInit: {body: 'existing body'}
})).toThrowError(/requestInit is expected to have no pre-existing body/);
});

it('Empty, null or undefined URL paths lead to Error', () => {
expect(() => tf.io.browserHTTPRequest(null))
expect(() => tf.io.http(null))
.toThrowError(/must not be null, undefined or empty/);
expect(() => tf.io.browserHTTPRequest(undefined))
expect(() => tf.io.http(undefined))
.toThrowError(/must not be null, undefined or empty/);
expect(() => tf.io.browserHTTPRequest(''))
expect(() => tf.io.http(''))
.toThrowError(/must not be null, undefined or empty/);
});

it('router', () => {
expect(httpRequestRouter('http://bar/foo') instanceof BrowserHTTPRequest)
.toEqual(true);
expect(
httpRequestRouter('https://localhost:5000/upload') instanceof
BrowserHTTPRequest)
expect(httpRouter('http://bar/foo') instanceof HTTPRequest).toEqual(true);
expect(httpRouter('https://localhost:5000/upload') instanceof HTTPRequest)
.toEqual(true);
expect(httpRequestRouter('localhost://foo')).toBeNull();
expect(httpRequestRouter('foo:5000/bar')).toBeNull();
expect(httpRouter('localhost://foo')).toBeNull();
expect(httpRouter('foo:5000/bar')).toBeNull();
});
});

Expand All @@ -435,7 +426,7 @@ describeWithFlags('parseUrl', BROWSER_ENVS, () => {
});
});

describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => {
describeWithFlags('http-load', BROWSER_ENVS, () => {
describe('JSON model', () => {
let requestInits: {[key: string]: {headers: {[key: string]: string}}};

Expand Down Expand Up @@ -474,14 +465,14 @@ describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => {
},
requestInits);

const handler = tf.io.browserHTTPRequest('./model.json');
const handler = tf.io.http('./model.json');
const modelArtifacts = await handler.load();
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights);
expect(new Float32Array(modelArtifacts.weightData)).toEqual(floatData);
expect(Object.keys(requestInits).length).toEqual(2);
// Assert that fetch is invoked with `window` as the context.
expect(windowFetchSpy.calls.mostRecent().object).toEqual(window);
expect(fetchSpy.calls.mostRecent().object).toEqual(window);
});

it('1 group, 2 weights, 1 path, with requestInit', async () => {
Expand Down Expand Up @@ -515,7 +506,7 @@ describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => {
},
requestInits);

const handler = tf.io.browserHTTPRequest(
const handler = tf.io.http(
'./model.json',
{requestInit: {headers: {'header_key_1': 'header_value_1'}}});
const modelArtifacts = await handler.load();
Expand All @@ -529,7 +520,7 @@ describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => {
expect(requestInits['./weightfile0'].headers['header_key_1'])
.toEqual('header_value_1');

expect(windowFetchSpy.calls.mostRecent().object).toEqual(window);
expect(fetchSpy.calls.mostRecent().object).toEqual(window);
});

it('1 group, 2 weight, 2 paths', async () => {
Expand Down Expand Up @@ -566,7 +557,7 @@ describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => {
},
requestInits);

const handler = tf.io.browserHTTPRequest('./model.json');
const handler = tf.io.http('./model.json');
const modelArtifacts = await handler.load();
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
expect(modelArtifacts.weightSpecs).toEqual(weightManifest1[0].weights);
Expand Down Expand Up @@ -609,7 +600,7 @@ describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => {
},
requestInits);

const handler = tf.io.browserHTTPRequest('./model.json');
const handler = tf.io.http('./model.json');
const modelArtifacts = await handler.load();
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
expect(modelArtifacts.weightSpecs)
Expand Down Expand Up @@ -654,7 +645,7 @@ describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => {
},
requestInits);

const handler = tf.io.browserHTTPRequest('path1/model.json');
const handler = tf.io.http('path1/model.json');
const modelArtifacts = await handler.load();
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
expect(modelArtifacts.weightSpecs)
Expand All @@ -676,7 +667,7 @@ describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => {
},
requestInits);

const handler = tf.io.browserHTTPRequest('./model.json');
const handler = tf.io.http('./model.json');
const modelArtifacts = await handler.load();
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
expect(modelArtifacts.weightSpecs).toBeUndefined();
Expand Down Expand Up @@ -717,7 +708,7 @@ describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => {
},
requestInits);

const handler = tf.io.browserHTTPRequest('path1/model.json');
const handler = tf.io.http('path1/model.json');
const modelArtifacts = await handler.load();
expect(modelArtifacts.modelTopology).toBeUndefined();
expect(modelArtifacts.weightSpecs)
Expand All @@ -737,7 +728,7 @@ describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => {
{data: JSON.stringify({}), contentType: 'application/json'}
},
requestInits);
const handler = tf.io.browserHTTPRequest('path1/model.json');
const handler = tf.io.http('path1/model.json');
handler.load()
.then(modelTopology1 => {
done.fail(
Expand All @@ -758,7 +749,7 @@ describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => {
{data: JSON.stringify({}), contentType: 'text/html'}
},
requestInits);
const handler = tf.io.browserHTTPRequest('path2/model.json');
const handler = tf.io.http('path2/model.json');
try {
const data = await handler.load();
expect(data).toBeDefined();
Expand Down Expand Up @@ -811,7 +802,7 @@ describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => {
}
}

const handler = tf.io.browserHTTPRequest(
const handler = tf.io.http(
'./model.json',
{requestInit: {credentials: 'include'}, fetchFunc: customFetch});
const modelArtifacts = await handler.load();
Expand Down
3 changes: 2 additions & 1 deletion src/io/io.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import './indexed_db';
import './local_storage';

import {browserFiles} from './browser_files';
import {browserHTTPRequest, isHTTPScheme} from './browser_http';
import {browserHTTPRequest, http, isHTTPScheme} from './http';
import {concatenateArrayBuffers, decodeWeights, encodeWeights, getModelArtifactsInfoForJSON} from './io_utils';
import {fromMemory, withSaveHandler} from './passthrough';
import {getLoadHandlers, getSaveHandlers, registerLoadRouter, registerSaveRouter} from './router_registry';
Expand All @@ -39,6 +39,7 @@ export {
getLoadHandlers,
getModelArtifactsInfoForJSON,
getSaveHandlers,
http,
IOHandler,
isHTTPScheme,
LoadHandler,
Expand Down