ONNX Runtime improvements #1306

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft · wants to merge 8 commits into main
227 changes: 69 additions & 158 deletions package-lock.json

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions package.json
@@ -27,7 +27,7 @@
"typegen": "tsc --build",
"dev": "webpack serve --no-client-overlay",
"build": "webpack && npm run typegen",
"test": "node --experimental-vm-modules --expose-gc node_modules/jest/bin/jest.js --verbose",
"test": "node --experimental-vm-modules --expose-gc node_modules/jest/bin/jest.js --verbose --logHeapUsage",
"readme": "python ./docs/scripts/build_readme.py",
"docs-api": "node ./docs/scripts/generate.js",
"docs-preview": "doc-builder preview transformers.js ./docs/source/ --not_python_module",
@@ -55,9 +55,9 @@
},
"homepage": "https://github.com/huggingface/transformers.js#readme",
"dependencies": {
"@huggingface/jinja": "^0.4.1",
"onnxruntime-node": "1.21.0",
"onnxruntime-web": "1.22.0-dev.20250409-89f8206ba4",
"@huggingface/jinja": "^0.5.0",
"onnxruntime-node": "1.23.0-dev.20250515-00bd398d54",
"onnxruntime-web": "1.23.0-dev.20250509-3dc91e6c31",
"sharp": "^0.34.1"
},
"devDependencies": {
26 changes: 19 additions & 7 deletions src/backends/onnx.js
@@ -38,6 +38,7 @@ const DEVICE_TO_EXECUTION_PROVIDER_MAPPING = Object.freeze({
webgpu: 'webgpu', // WebGPU
cuda: 'cuda', // CUDA
dml: 'dml', // DirectML
coreml: 'coreml', // CoreML

webnn: { name: 'webnn', deviceType: 'cpu' }, // WebNN (default)
'webnn-npu': { name: 'webnn', deviceType: 'npu' }, // WebNN NPU
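
With CoreML added to the device-to-execution-provider mapping, macOS users can opt into it the same way as any other backend. A minimal sketch of the expected usage, assuming the standard `pipeline` API (the model id is illustrative, not part of this PR):

```js
import { pipeline } from '@huggingface/transformers';

// Request the CoreML execution provider explicitly (macOS only).
// Any ONNX model id works here; this one is just an example.
const extractor = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2', {
  device: 'coreml',
});

const embeddings = await extractor('Hello world', { pooling: 'mean', normalize: true });
console.log(embeddings.dims); // e.g. [1, 384]
```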
@@ -63,13 +64,15 @@ if (ORT_SYMBOL in globalThis) {
} else if (apis.IS_NODE_ENV) {
ONNX = ONNX_NODE.default ?? ONNX_NODE;

// Updated as of ONNX Runtime 1.20.1
// Updated as of ONNX Runtime 1.22.0-dev.20250418-c19a49615b
// The following table lists the supported versions of ONNX Runtime Node.js binding provided with pre-built binaries.
// | EPs/Platforms | Windows x64 | Windows arm64 | Linux x64 | Linux arm64 | MacOS x64 | MacOS arm64 |
// | ------------- | ----------- | ------------- | ----------------- | ----------- | --------- | ----------- |
// | CPU | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
// | DirectML | ✔️ | ✔️ | ❌ | ❌ | ❌ | ❌ |
// | CUDA | ❌ | ❌ | ✔️ (CUDA v11.8) | ❌ | ❌ | ❌ |
// | EPs/Platforms | Windows x64 | Windows arm64 | Linux x64 | Linux arm64 | MacOS x64 | MacOS arm64 |
// | --------------------- | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ | ------------------ |
// | CPU | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
// | WebGPU (experimental) | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
// | DirectML | ✔️ | ✔️ | ❌ | ❌ | ❌ | ❌ |
// | CUDA | ❌ | ❌ | ✔️ (CUDA v12) | ❌ | ❌ | ❌ |
// | CoreML | ❌ | ❌ | ❌ | ❌ | ✔️ | ✔️ |
switch (process.platform) {
case 'win32': // Windows x64 and Windows arm64
supportedDevices.push('dml');
@@ -80,9 +83,11 @@ if (ORT_SYMBOL in globalThis) {
}
break;
case 'darwin': // MacOS x64 and MacOS arm64
supportedDevices.push('coreml');
break;
}

supportedDevices.push('webgpu');
supportedDevices.push('cpu');
defaultDevices = ['cpu'];
} else {
@@ -180,9 +185,16 @@ if (ONNX_ENV?.wasm) {
if (
// @ts-ignore Cannot find name 'ServiceWorkerGlobalScope'.ts(2304)
!(typeof ServiceWorkerGlobalScope !== 'undefined' && self instanceof ServiceWorkerGlobalScope)
&& env.backends.onnx.versions?.web
&& !ONNX_ENV.wasm.wasmPaths
) {
ONNX_ENV.wasm.wasmPaths = `https://cdn.jsdelivr.net/npm/@huggingface/transformers@${env.version}/dist/`;
const wasmPathPrefix = `https://cdn.jsdelivr.net/npm/onnxruntime-web@${env.backends.onnx.versions.web}/dist/`;

ONNX_ENV.wasm.wasmPaths = apis.IS_SAFARI ? {
"mjs": `${wasmPathPrefix}/ort-wasm-simd-threaded.mjs`,
"wasm": `${wasmPathPrefix}/ort-wasm-simd-threaded.wasm`,
}
: wasmPathPrefix;
}

// TODO: Add support for loading WASM files from cached buffer when we upgrade to onnxruntime-web@1.19.0
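
The WASM path now tracks the exact onnxruntime-web version in use, with Safari getting an explicit `.mjs`/`.wasm` pair via the new `IS_SAFARI` flag defined below. Consumers who self-host the runtime can still override this before the first model loads; a sketch, assuming the package's public `env` export:

```js
import { env } from '@huggingface/transformers';

// Serve the ONNX Runtime WASM artifacts from your own origin instead
// of the jsDelivr URL computed above. Must run before any model loads.
env.backends.onnx.wasm.wasmPaths = '/static/ort/';
```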
33 changes: 32 additions & 1 deletion src/env.js
@@ -30,11 +30,39 @@ const VERSION = '3.5.1';

// Check if various APIs are available (depends on environment)
const IS_BROWSER_ENV = typeof window !== "undefined" && typeof window.document !== "undefined";
const IS_WEBWORKER_ENV = typeof self !== "undefined" && self.constructor?.name === 'DedicatedWorkerGlobalScope';
const IS_WEB_CACHE_AVAILABLE = typeof self !== "undefined" && 'caches' in self;
const IS_WEBGPU_AVAILABLE = typeof navigator !== 'undefined' && 'gpu' in navigator;
const IS_WEBNN_AVAILABLE = typeof navigator !== 'undefined' && 'ml' in navigator;

/**
* Check if the current environment is Safari browser.
* Works in both browser and web worker contexts.
* @returns {boolean} Whether the current environment is Safari.
*/
const isSafari = () => {
// Check if we're in a browser environment
if (typeof navigator === 'undefined') {
return false;
}

const userAgent = navigator.userAgent;
const vendor = navigator.vendor || '';

// Safari has "Apple" in vendor string
const isAppleVendor = vendor.indexOf('Apple') > -1;

// Exclude Chrome on iOS (CriOS), Firefox on iOS (FxiOS),
// Edge on iOS (EdgiOS), and other browsers
const notOtherBrowser =
!userAgent.match(/CriOS|FxiOS|EdgiOS|OPiOS|mercury|brave/i) &&
!userAgent.includes('Chrome') &&
!userAgent.includes('Android');

return isAppleVendor && notOtherBrowser;
};
const IS_SAFARI = isSafari();

const IS_PROCESS_AVAILABLE = typeof process !== 'undefined';
const IS_NODE_ENV = IS_PROCESS_AVAILABLE && process?.release?.name === 'node';
const IS_FS_AVAILABLE = !isEmpty(fs);
@@ -59,6 +87,9 @@ export const apis = Object.freeze({
/** Whether the WebNN API is available */
IS_WEBNN_AVAILABLE,

/** Whether we are running in a Safari browser */
IS_SAFARI,

/** Whether the Node.js process API is available */
IS_PROCESS_AVAILABLE,

26 changes: 14 additions & 12 deletions src/models.js
@@ -154,10 +154,11 @@ const MODEL_CLASS_TO_NAME_MAPPING = new Map();
* @param {string} pretrained_model_name_or_path The path to the directory containing the model file.
* @param {string} fileName The name of the model file.
* @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model.
* @param {boolean} [is_decoder=false] Whether the model is a decoder model.
* @returns {Promise<{buffer_or_path: Uint8Array|string, session_options: Object, session_config: Object}>} A Promise that resolves to the data needed to create an InferenceSession object.
* @private
*/
async function getSession(pretrained_model_name_or_path, fileName, options) {
async function getSession(pretrained_model_name_or_path, fileName, options, is_decoder=false) {
let custom_config = options.config?.['transformers.js_config'] ?? {};

let device = options.device ?? custom_config.device;
@@ -316,7 +317,7 @@
}
}

if (selectedDevice === 'webgpu') {
if (is_decoder && selectedDevice === 'webgpu') {
const shapes = getKeyValueShapes(options.config, {
prefix: 'present',
});
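
The new `is_decoder` guard means only genuine decoder sessions pay for the key/value-cache shape computation; previously every WebGPU session did. The body below this line is elided in the diff, but one plausible use of `shapes`, based on onnxruntime-web's documented `preferredOutputLocation` session option, is pinning the `present.*` outputs to GPU memory (a hypothetical reconstruction, not necessarily the PR's exact code):

```js
// Keep decoder KV-cache outputs on the GPU so they can be fed back
// as inputs on the next generation step without a CPU round-trip.
// `shapes` maps output names like 'present.0.key' to dimension arrays.
session_options.preferredOutputLocation = Object.fromEntries(
  Object.keys(shapes).map((name) => [name, 'gpu-buffer'])
);
```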
@@ -342,13 +343,14 @@
* @param {string} pretrained_model_name_or_path The path to the directory containing the model file.
* @param {Record<string, string>} names The names of the model files to load.
* @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model.
* @param {string} [decoder_name] The name of the decoder model, if any.
* @returns {Promise<Record<string, any>>} A Promise that resolves to a dictionary of InferenceSession objects.
* @private
*/
async function constructSessions(pretrained_model_name_or_path, names, options) {
async function constructSessions(pretrained_model_name_or_path, names, options, decoder_name=undefined) {
return Object.fromEntries(await Promise.all(
Object.keys(names).map(async (name) => {
const { buffer_or_path, session_options, session_config } = await getSession(pretrained_model_name_or_path, names[name], options);
const { buffer_or_path, session_options, session_config } = await getSession(pretrained_model_name_or_path, names[name], options, name === decoder_name);
const session = await createInferenceSession(buffer_or_path, session_options, session_config);
return [name, session];
})
@@ -1148,7 +1150,7 @@ export class PreTrainedModel extends Callable {
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, {
model: options.model_file_name ?? 'model',
}, options),
}, options, 'model'),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
@@ -1159,7 +1161,7 @@
constructSessions(pretrained_model_name_or_path, {
model: 'encoder_model',
decoder_model_merged: 'decoder_model_merged',
}, options),
}, options, 'decoder_model_merged'),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
@@ -1178,7 +1180,7 @@
constructSessions(pretrained_model_name_or_path, {
model: 'encoder_model',
decoder_model_merged: 'decoder_model_merged',
}, options),
}, options, 'decoder_model_merged'),
]);

} else if (modelType === MODEL_TYPES.ImageTextToText) {
@@ -1191,7 +1193,7 @@
sessions['model'] = 'encoder_model';
}
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, sessions, options),
constructSessions(pretrained_model_name_or_path, sessions, options, 'decoder_model_merged'),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
@@ -1204,7 +1206,7 @@
decoder_model_merged: 'decoder_model_merged',
}
info = await Promise.all([
constructSessions(pretrained_model_name_or_path, sessions, options),
constructSessions(pretrained_model_name_or_path, sessions, options, 'decoder_model_merged'),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
@@ -1216,7 +1218,7 @@
model: 'text_encoder',
decoder_model_merged: 'decoder_model_merged',
encodec_decode: 'encodec_decode',
}, options),
}, options, 'decoder_model_merged'),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
@@ -1231,7 +1233,7 @@
gen_head: 'gen_head',
gen_img_embeds: 'gen_img_embeds',
image_decode: 'image_decode',
}, options),
}, options, 'model'),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
@@ -1243,7 +1245,7 @@
prepare_inputs_embeds: 'prepare_inputs_embeds',
model: 'model',
vision_encoder: 'vision_encoder',
}, options),
}, options, 'model'),
getOptionalConfigs(pretrained_model_name_or_path, {
generation_config: 'generation_config.json',
}, options),
1 change: 1 addition & 0 deletions src/utils/devices.js
@@ -10,6 +10,7 @@ export const DEVICE_TYPES = Object.freeze({
webgpu: 'webgpu', // WebGPU
cuda: 'cuda', // CUDA
dml: 'dml', // DirectML
coreml: 'coreml', // CoreML

webnn: 'webnn', // WebNN (default)
'webnn-npu': 'webnn-npu', // WebNN NPU
6 changes: 3 additions & 3 deletions tests/models/florence2/test_modeling_florence2.js
@@ -46,7 +46,7 @@ export default () => {
{
const inputs = await processor(image, texts[0]);
const generate_ids = await model.generate({ ...inputs, max_new_tokens: 10 });
expect(generate_ids.tolist()).toEqual([[2n, 0n, 48n, 48n, 48n, 48n, 48n, 48n, 48n, 48n, 2n]]);
expect(generate_ids.tolist()).toEqual([[2n, 0n, 0n, 0n, 1n, 0n, 0n, 2n]]);
}
},
MAX_TEST_EXECUTION_TIME,
@@ -68,8 +68,8 @@

const generate_ids = await model.generate({ ...inputs, max_new_tokens: 10 });
expect(generate_ids.tolist()).toEqual([
[2n, 0n, 48n, 48n, 48n, 48n, 48n, 48n, 48n, 48n, 2n],
[2n, 0n, 48n, 48n, 48n, 48n, 48n, 48n, 48n, 48n, 2n],
[2n, 0n, 0n, 0n, 1n, 0n, 0n, 2n],
[2n, 0n, 0n, 0n, 1n, 0n, 0n, 2n],
]);
}
},
@@ -32,7 +32,7 @@ export default () => {
expect(pred_boxes.dims).toEqual([1, num_queries, 4]);
expect(logits.max().item()).toBeCloseTo(56.237613677978516, 2);
expect(logits.min().item()).toEqual(-Infinity);
expect(pred_boxes.mean().item()).toEqual(0.2500016987323761);
expect(pred_boxes.mean().item()).toBeCloseTo(0.2500016987323761, 4);
},
MAX_TEST_EXECUTION_TIME,
);
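
Replacing exact equality with `toBeCloseTo` makes the assertion tolerant of the backend-dependent floating-point drift introduced by the runtime upgrade. Jest's `toBeCloseTo(expected, numDigits)` passes when `|received - expected| < 10^-numDigits / 2`, so for example:

```js
// Passes: the difference (~2e-9) is far below the 5e-5 tolerance.
expect(0.2500016965).toBeCloseTo(0.2500016987323761, 4);

// Would fail with toEqual, since the bit patterns differ.
```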