Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions src/backends/onnx.js
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ import { env, apis } from '../env.js';
// In either case, we select the default export if it exists, otherwise we use the named export.
import * as ONNX_NODE from 'onnxruntime-node';
import * as ONNX_WEB from 'onnxruntime-web/webgpu';
import { loadWasmBinary, loadWasmFactory } from './utils/cacheWasm.js';

import { isBlobURL, loadWasmBinary, loadWasmFactory, toAbsoluteURL } from './utils/cacheWasm.js';
export { Tensor } from 'onnxruntime-common';

/**
Expand Down Expand Up @@ -186,10 +185,10 @@ async function ensureWasmLoaded() {
// Load and cache both the WASM binary and factory
await Promise.all([
// Load and cache the WASM binary
urls.wasm
urls.wasm && !isBlobURL(urls.wasm)
? (async () => {
try {
const wasmBinary = await loadWasmBinary(urls.wasm);
const wasmBinary = await loadWasmBinary(toAbsoluteURL(urls.wasm));
if (wasmBinary) {
ONNX_ENV.wasm.wasmBinary = wasmBinary;
}
Expand All @@ -200,10 +199,10 @@ async function ensureWasmLoaded() {
: Promise.resolve(),

// Load and cache the WASM factory
urls.mjs
urls.mjs && !isBlobURL(urls.mjs)
? (async () => {
try {
const wasmFactoryBlob = await loadWasmFactory(urls.mjs);
const wasmFactoryBlob = await loadWasmFactory(toAbsoluteURL(urls.mjs));
if (wasmFactoryBlob) {
// @ts-ignore
ONNX_ENV.wasm.wasmPaths.mjs = wasmFactoryBlob;
Expand All @@ -228,11 +227,12 @@ async function ensureWasmLoaded() {
*/
export async function createInferenceSession(buffer_or_path, session_options, session_config) {
await ensureWasmLoaded();
const load = () => InferenceSession.create(buffer_or_path, {
// Set default log level, but allow overriding through session options
logSeverityLevel: DEFAULT_LOG_LEVEL,
...session_options,
});
const load = () =>
InferenceSession.create(buffer_or_path, {
// Set default log level, but allow overriding through session options
logSeverityLevel: DEFAULT_LOG_LEVEL,
...session_options,
});
const session = await (IS_WEB_ENV ? (webInitChain = webInitChain.then(load)) : load());
session.config = session_config;
return session;
Expand Down
36 changes: 35 additions & 1 deletion src/backends/utils/cacheWasm.js
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { getCache } from '../../utils/cache.js';
import { isValidUrl } from '../../utils/hub/utils.js';

/**
* Loads and caches a file from the given URL.
Expand Down Expand Up @@ -41,7 +42,6 @@ async function loadAndCacheFile(url) {
}

return response;

}

/**
Expand Down Expand Up @@ -83,3 +83,37 @@ export async function loadWasmFactory(libURL) {
return null;
}
}

/**
* Checks if the given URL is a blob URL (created via URL.createObjectURL).
* Blob URLs should not be cached as they are temporary in-memory references.
* @param {string} url - The URL to check.
* @returns {boolean} True if the URL is a blob URL, false otherwise.
*/
export function isBlobURL(url) {
return isValidUrl(url, ['blob:']);
}

/**
* Converts any URL to an absolute URL if needed.
* If the URL is already absolute (http://, https://, or blob:), returns it unchanged (handled by new URL(...)).
* Otherwise, resolves it relative to the current page location (browser) or module location (Node/Bun/Deno).
* @param {string} url - The URL to convert (can be relative or absolute).
* @returns {string} The absolute URL.
*/
export function toAbsoluteURL(url) {
let baseURL;

if (typeof location !== 'undefined' && location.href) {
// Browser environment: use location.href
baseURL = location.href;
} else if (typeof import.meta !== 'undefined' && import.meta.url) {
// Node.js/Bun/Deno module environment: use import.meta.url
baseURL = import.meta.url;
} else {
// Fallback: if no base is available, return the URL unchanged
return url;
}

return new URL(url, baseURL).href;
}
Loading