Skip to content

Commit

Permalink
refactor: optimize
Browse files Browse the repository at this point in the history
  • Loading branch information
mys1024 committed Mar 9, 2024
1 parent 578e3e8 commit 762d90f
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 99 deletions.
21 changes: 11 additions & 10 deletions src/define.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import type { AnyFn, DefineWorkerFnOpts, InternalFns } from "./types.ts";
import type { MsgPort, MsgPortNormalized } from "./rpc/types.ts";
import { getGlobal } from "./rpc/utils.ts";
import { RpcAgent } from "./rpc/rpc.ts";

/* -------------------------------------------------- common -------------------------------------------------- */

const _global = getGlobal() as MsgPortNormalized;
const _global = Function("return this")() as MsgPortNormalized;

/* -------------------------------------------------- defineWorkerFn() -------------------------------------------------- */

Expand All @@ -16,9 +15,9 @@ const _global = getGlobal() as MsgPortNormalized;
* @param fn - The worker function.
* @param options - An object containing options.
*/
export function defineWorkerFn<FN extends AnyFn>(
export function defineWorkerFn(
name: string,
fn: FN,
fn: AnyFn,
options: DefineWorkerFnOpts = {},
): void {
const { transfer, port = _global } = options;
Expand All @@ -35,7 +34,7 @@ export function defineWorkerFn<FN extends AnyFn>(
/**
* Invoke this function in worker threads to define worker functions.
*
* @param functions - An object containing worker functions. Keys will be used as the name of the worker functions.
* @param functions - An object containing worker functions. Keys will be used as the names of the worker functions.
* @param options - An object containing options.
*/
export function defineWorkerFns(
Expand All @@ -49,24 +48,26 @@ export function defineWorkerFns(

/* -------------------------------------------------- internal functions -------------------------------------------------- */

const INTERNAL_FNS_DEFINED_SYM = Symbol("INTERNAL_FNS_DEFINED_SYM");
const INTERNAL_FNS_DEFINED = Symbol("internalFnsDefined");

/**
* Ensure that internal functions are defined.
*/
function ensureInternalFns(port: MsgPort) {
// prevent double usage
if ((port as any)[INTERNAL_FNS_DEFINED_SYM]) {
if ((port as any)[INTERNAL_FNS_DEFINED]) {
return;
}
(port as any)[INTERNAL_FNS_DEFINED_SYM] = true;
(port as any)[INTERNAL_FNS_DEFINED] = true;
// get RpcAgent instance
const rpcAgent = RpcAgent.getRpcAgent(port);
// internal functions
// define internal functions
const internalFns: InternalFns = {
/**
* @returns Names of defined worker functions
*/
names: () => rpcAgent.getLocalFnNames("fn"),
};
// define internal functions
for (const [name, fn] of Object.entries(internalFns)) {
rpcAgent.defineLocalFn(name, fn, { namespace: "fn-internal" });
}
Expand Down
41 changes: 23 additions & 18 deletions src/rpc/rpc.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@ import {
} from "./utils.ts";

const DEFAULT_NAMESPACE = "rpc";
const RPC_AGENT_SYM = Symbol("RPC_AGENT_SYM");
const RPC_AGENT = Symbol("rpcAgent");

export class RpcAgent {
#msgPort: MsgPortNormalized;
#callCount = 0;

/** namespace -> name -> fnConf */
#localFns = new Map<
Expand All @@ -38,17 +39,17 @@ export class RpcAgent {
>();

static getRpcAgent(msgPort: MsgPort): RpcAgent {
return (msgPort as any)[RPC_AGENT_SYM] || new RpcAgent(msgPort);
return (msgPort as any)[RPC_AGENT] || new RpcAgent(msgPort);
}

constructor(msgPort: MsgPort) {
// prevent double usage
if ((msgPort as any)[RPC_AGENT_SYM]) {
if ((msgPort as any)[RPC_AGENT]) {
throw new Error(
"The MsgPort has already been used by another RpcAgent instance, invoke `RpcAgent.getRpcAgent()` to get that RpcAgent instance instead.",
);
}
(msgPort as any)[RPC_AGENT_SYM] = this;
(msgPort as any)[RPC_AGENT] = this;
// init properties
this.#msgPort = toMsgPortNormalized(msgPort);
// start listening to messages
Expand Down Expand Up @@ -81,17 +82,15 @@ export class RpcAgent {

const keyCallMap = this.#getKeyCallMap(namespace, true);

const key = Math.random();
const key = ++this.#callCount;
const ret = new Promise<AwaitedRet<FN>>((resolve, reject) => {
keyCallMap.set(key, { resolve, reject });
});

this.#sendCallMsg({
meta: {
ns: namespace,
name,
key,
},
ns: namespace,
name,
key,
type: "call",
args,
}, {
Expand Down Expand Up @@ -143,14 +142,15 @@ export class RpcAgent {
#startListening() {
this.#msgPort.addEventListener("message", async (event) => {
if (isRpcCallMsg(event.data)) {
const { meta, args } = event.data;
const { ns, name } = meta;
const { ns, name, key, args } = event.data;
// get the local function
const nameFnMap = this.#getNameFnMap(ns, false);
if (!nameFnMap) {
this.#sendReturnMsg({
meta,
type: "return",
ns,
name,
key,
ok: false,
err: new Error(`The namespace "${ns}" is not defined.`),
});
Expand All @@ -159,8 +159,10 @@ export class RpcAgent {
const fnConf = nameFnMap.get(name);
if (!fnConf) {
this.#sendReturnMsg({
meta,
type: "return",
ns,
name,
key,
ok: false,
err: new Error(
`The name "${name}" is not defined in namespace "${ns}".`,
Expand All @@ -173,24 +175,27 @@ export class RpcAgent {
try {
const ret = await fn(...args);
this.#sendReturnMsg({
meta,
type: "return",
ns,
name,
key,
ok: true,
ret,
}, {
transfer: transfer && isTransferable(ret) ? [ret] : undefined,
});
} catch (err) {
this.#sendReturnMsg({
meta,
type: "return",
ns,
name,
key,
ok: false,
err,
});
}
} else if (isRpcReturnMsg(event.data)) {
const { meta, ok, ret, err } = event.data;
const { ns, name, key } = meta;
const { ns, name, key, ok, ret, err } = event.data;
// get the promise resolvers
const keyCallMap = this.#getKeyCallMap(ns, false);
if (!keyCallMap) {
Expand Down
21 changes: 14 additions & 7 deletions src/rpc/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ export type AwaitedRet<FN extends AnyFn> = Awaited<ReturnType<FN>>;

/* -------------------------------------------------- msg -------------------------------------------------- */

/**
* Worker
*/
export interface MsgPortNormalized {
postMessage(
message: any,
Expand All @@ -19,6 +22,9 @@ export interface MsgPortNormalized {
) => void;
}

/**
* node:worker_threads
*/
export interface MsgPortNode {
postMessage(value: any): void;
on(event: "message", listener: (value: any) => void): void;
Expand All @@ -28,20 +34,21 @@ export type MsgPort = MsgPortNormalized | MsgPortNode;

/* -------------------------------------------------- RPC -------------------------------------------------- */

export interface RpcMeta {
export type RpcCallMsg<FN extends AnyFn = AnyFn> = {
type: "call";
ns: string;
name: string;
key: number;
}

export type RpcCallMsg<FN extends AnyFn = AnyFn> = {
meta: RpcMeta;
type: "call";
args: Parameters<FN>;
};

export type RpcReturnMsg<FN extends AnyFn = AnyFn> =
& { meta: RpcMeta; type: "return" }
& {
type: "return";
ns: string;
name: string;
key: number;
}
& (
| { ok: true; ret: Awaited<ReturnType<FN>>; err?: undefined }
| { ok: false; ret?: undefined; err: any }
Expand Down
49 changes: 21 additions & 28 deletions src/rpc/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,39 +5,32 @@ import {
RpcReturnMsg,
} from "./types.ts";

const _global = getGlobal();

/**
* @see https://developer.mozilla.org/en-US/docs/Web/API/Web_Workers_API/Transferable_objects#supported_objects
*/
const transferableClasses = [
_global.ArrayBuffer,
_global.MessagePort,
_global.ReadableStream,
_global.WritableStream,
_global.TransformStream,
_global.WebTransportReceiveStream,
_global.WebTransportSendStream,
_global.AudioData,
_global.ImageBitmap,
_global.VideoFrame,
_global.OffscreenCanvas,
_global.RTCDataChannel,
].filter((c) => !!c);
const transferableClasses = (function () {
return [
"ArrayBuffer",
"MessagePort",
"ReadableStream",
"WritableStream",
"TransformStream",
"WebTransportReceiveStream",
"WebTransportSendStream",
"AudioData",
"ImageBitmap",
"VideoFrame",
"OffscreenCanvas",
"RTCDataChannel",
].map((name) => getGlobalVar(name)).filter((v) => !!v);
})();

export function getGlobal() {
const globalNames = ["globalThis", "self", "window", "global"];
for (const name of globalNames) {
try {
const g = eval(name);
if (g) {
return g;
}
} catch {
continue;
}
export function getGlobalVar(name: string) {
try {
return Function(`return ${name}`)();
} catch {
return undefined;
}
throw new Error("Failed to get the global object.");
}

export function isTransferable(val: any): val is Transferable {
Expand Down
29 changes: 24 additions & 5 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,20 @@ import type { MsgPort } from "./rpc/types.ts";

export type AnyFn = (...args: any[]) => any;

export type AwaitedRet<FN extends AnyFn> = Awaited<ReturnType<FN>>;

export type ProxyFn<FN extends AnyFn> = (
...args: Parameters<FN>
) => Promise<AwaitedRet<FN>>;

export type ProxyFns<FNS extends Record<string, AnyFn>> = {
[P in keyof FNS]: ProxyFn<FNS[P]>;
};

export type InternalFns = {
names: () => string[];
};

export interface DefineWorkerFnOpts {
/**
* Whether to transfer the return value of worker function if it is of type `Transferable`.
Expand All @@ -11,6 +25,15 @@ export interface DefineWorkerFnOpts {
*/
transfer?: boolean;

/**
* The message port to communicate with the main thread.
*
* @default self
* @example
* // In Node.js
* import { parentPort } from "node:worker_threads";
* defineWorkerFn("add", add, { port: parentPort! });
*/
port?: MsgPort;
}

Expand All @@ -19,12 +42,8 @@ export interface UseWorkerFnOpts<FN extends AnyFn> {
* A function that determines objects to be transferred when posting messages to the worker thread.
*
* @see https://developer.mozilla.org/en-US/docs/Web/API/Worker/postMessage#transfer
* @param ctx The context of proxy function invocation.
* @param ctx The context of the function call.
* @returns Transferable objects.
*/
transfer?: (ctx: { args: Parameters<FN> }) => Transferable[];
}

export type InternalFns = {
names: () => string[];
};
26 changes: 10 additions & 16 deletions src/use.ts
Original file line number Diff line number Diff line change
@@ -1,26 +1,20 @@
import type { AnyFn, InternalFns, UseWorkerFnOpts } from "./types.ts";
import type { MsgPort } from "./rpc/types.ts";
import { RpcAgent } from "./rpc/rpc.ts";

/* -------------------------------------------------- common -------------------------------------------------- */

type AwaitedRet<FN extends AnyFn> = Awaited<ReturnType<FN>>;

type ProxyFn<FN extends AnyFn> = (
...args: Parameters<FN>
) => Promise<AwaitedRet<FN>>;

type ProxyFns<FNS extends Record<string, AnyFn>> = {
[P in keyof FNS]: ProxyFn<FNS[P]>;
};
import type {
AnyFn,
InternalFns,
ProxyFn,
ProxyFns,
UseWorkerFnOpts,
} from "./types.ts";

/* -------------------------------------------------- useWorkerFn() -------------------------------------------------- */

/**
* Invoke this function in the main thread to create a proxy function that calls the corresponding worker function.
*
* @param name - The name that identifies the worker function.
* @param worker - Either a Worker instance or an object containing options for creating a lazy Worker instance.
* @param worker - A Worker instance.
* @param options - An object containing options.
* @returns The proxy function.
*/
Expand Down Expand Up @@ -75,10 +69,10 @@ export function useWorkerFns<FNS extends Record<string, AnyFn>>(
/**
* Inspect a worker.
*
* @param worker The worker.
* @param worker A worker instance.
* @returns Information about the worker.
*/
export async function inspectWorker(worker: Worker): Promise<{
export async function inspectWorker(worker: MsgPort): Promise<{
names: string[];
}> {
const rpcAgent = RpcAgent.getRpcAgent(worker);
Expand Down
Loading

0 comments on commit 762d90f

Please sign in to comment.