Skip to content

Custom RPC opts mapper #539

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions packages/restate-sdk/src/common_api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ export type {
SendClient,
ClientCallOptions,
ClientSendOptions,
ClientCallOptsMapper,
ClientSendOptsMapper,
RemoveVoidArgument,
} from "./types/rpc.js";
export {
Expand Down
34 changes: 29 additions & 5 deletions packages/restate-sdk/src/context_impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@ import {
TerminalError,
UNKNOWN_ERROR_CODE,
} from "./types/errors.js";
import type { Client, SendClient } from "./types/rpc.js";
import type {
Client,
SendClient,
ClientCallOptsMapper,
ClientSendOptsMapper,
} from "./types/rpc.js";
import {
defaultSerde,
HandlerKind,
Expand Down Expand Up @@ -100,7 +105,9 @@ export class ContextImpl implements ObjectContext, WorkflowContext {
private readonly invocationRequest: Request,
private readonly invocationEndPromise: CompletablePromise<void>,
inputReader: ReadableStreamDefaultReader<Uint8Array>,
outputWriter: WritableStreamDefaultWriter<Uint8Array>
outputWriter: WritableStreamDefaultWriter<Uint8Array>,
private readonly clientCallOptsMapper: ClientCallOptsMapper,
private readonly clientSendOptsMapper: ClientSendOptsMapper
) {
this.rand = new RandImpl(input.invocation_id, () => {
// TODO reimplement this check with async context
Expand Down Expand Up @@ -280,21 +287,35 @@ export class ContextImpl implements ObjectContext, WorkflowContext {
}

serviceClient<D>({ name }: ServiceDefinitionFrom<D>): Client<Service<D>> {
return makeRpcCallProxy((call) => this.genericCall(call), name);
return makeRpcCallProxy(
(call) => this.genericCall(call),
this.clientCallOptsMapper,
name
);
}

objectClient<D>(
{ name }: VirtualObjectDefinitionFrom<D>,
key: string
): Client<VirtualObject<D>> {
return makeRpcCallProxy((call) => this.genericCall(call), name, key);
return makeRpcCallProxy(
(call) => this.genericCall(call),
this.clientCallOptsMapper,
name,
key
);
}

workflowClient<D>(
{ name }: WorkflowDefinitionFrom<D>,
key: string
): Client<Workflow<D>> {
return makeRpcCallProxy((call) => this.genericCall(call), name, key);
return makeRpcCallProxy(
(call) => this.genericCall(call),
this.clientCallOptsMapper,
name,
key
);
}

public serviceSendClient<D>(
Expand All @@ -303,6 +324,7 @@ export class ContextImpl implements ObjectContext, WorkflowContext {
): SendClient<Service<D>> {
return makeRpcSendProxy(
(send) => this.genericSend(send),
this.clientSendOptsMapper,
name,
undefined,
opts?.delay
Expand All @@ -316,6 +338,7 @@ export class ContextImpl implements ObjectContext, WorkflowContext {
): SendClient<VirtualObject<D>> {
return makeRpcSendProxy(
(send) => this.genericSend(send),
this.clientSendOptsMapper,
name,
key,
opts?.delay
Expand All @@ -329,6 +352,7 @@ export class ContextImpl implements ObjectContext, WorkflowContext {
): SendClient<Workflow<D>> {
return makeRpcSendProxy(
(send) => this.genericSend(send),
this.clientSendOptsMapper,
name,
key,
opts?.delay
Expand Down
24 changes: 24 additions & 0 deletions packages/restate-sdk/src/endpoint.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ import type {
WorkflowDefinition,
} from "@restatedev/restate-sdk-core";
import type { LoggerTransport } from "./logging/logger_transport.js";
import type {
ClientCallOptsMapper,
ClientSendOptsMapper,
} from "./types/rpc.js";

export interface RestateEndpointBase<E> {
/**
Expand Down Expand Up @@ -63,6 +67,26 @@ export interface RestateEndpointBase<E> {
* ```
*/
setLogger(logger: LoggerTransport): E;

/**
* Set a ClientCallOptions mapper function that will be called on every RPC call
*
* Can be used to set default headers, or a default input/output serde e.g. always binary instead of JSON
*
* The mapper function will receive the ClientCallOptions provided by the handler
* and should return either a ClientCallOptions object or undefined
*/
setClientCallOptsMapper(optsMapper: ClientCallOptsMapper): E;

/**
* Set a ClientSendOptions mapper function that will be called on every RPC call
*
* Can be used to set default headers, or a default input/output serde e.g. always binary instead of JSON
*
* The mapper function will receive the ClientSendOptions provided by the handler
* and should return either a ClientSendOptions object or undefined
*/
setClientSendOptsMapper(optsMapper: ClientSendOptsMapper): E;
}

/**
Expand Down
30 changes: 26 additions & 4 deletions packages/restate-sdk/src/endpoint/endpoint_builder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@ import type {
WorkflowDefinition,
} from "@restatedev/restate-sdk-core";

import { HandlerWrapper } from "../types/rpc.js";
import type {
ClientCallOptsMapper,
ClientSendOptsMapper,
} from "../types/rpc.js";
import { HandlerWrapper, defaultClientOptsMapper } from "../types/rpc.js";
import type { Component } from "../types/components.js";
import {
ServiceComponent,
Expand Down Expand Up @@ -74,10 +78,22 @@ export class EndpointBuilder {

private _keySet: string[] = [];

private clientCallOptsMapper: ClientCallOptsMapper = defaultClientOptsMapper;

private clientSendOptsMapper: ClientSendOptsMapper = defaultClientOptsMapper;

public get keySet(): string[] {
return this._keySet;
}

public setClientCallOptsMapper(optsMapper: ClientCallOptsMapper) {
this.clientCallOptsMapper = optsMapper;
}

public setClientSendOptsMapper(optsMapper: ClientSendOptsMapper) {
this.clientSendOptsMapper = optsMapper;
}

public componentByName(componentName: string): Component | undefined {
return this.services.get(componentName);
}
Expand Down Expand Up @@ -165,7 +181,9 @@ export class EndpointBuilder {
const component = new ServiceComponent(
name,
definition.description,
definition.metadata
definition.metadata,
this.clientCallOptsMapper,
this.clientSendOptsMapper
);

for (const [route, handler] of Object.entries(
Expand Down Expand Up @@ -193,7 +211,9 @@ export class EndpointBuilder {
const component = new VirtualObjectComponent(
name,
definition.description,
definition.metadata
definition.metadata,
this.clientCallOptsMapper,
this.clientSendOptsMapper
);

for (const [route, handler] of Object.entries(
Expand All @@ -220,7 +240,9 @@ export class EndpointBuilder {
const component = new WorkflowComponent(
name,
definition.description,
definition.metadata
definition.metadata,
this.clientCallOptsMapper,
this.clientSendOptsMapper
);

for (const [route, handler] of Object.entries(
Expand Down
18 changes: 18 additions & 0 deletions packages/restate-sdk/src/endpoint/fetch_endpoint.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ import type {
VirtualObjectDefinition,
WorkflowDefinition,
} from "@restatedev/restate-sdk-core";
import type {
ClientCallOptsMapper,
ClientSendOptsMapper,
} from "../types/rpc.js";
import type { Component } from "../types/components.js";
import { EndpointBuilder } from "./endpoint_builder.js";
import type { RestateEndpointBase } from "../endpoint.js";
Expand Down Expand Up @@ -89,6 +93,20 @@ export class FetchEndpointImpl implements FetchEndpoint {
return this;
}

public setClientCallOptsMapper(
optsMapper: ClientCallOptsMapper
): FetchEndpoint {
this.builder.setClientCallOptsMapper(optsMapper);
return this;
}

public setClientSendOptsMapper(
optsMapper: ClientSendOptsMapper
): FetchEndpoint {
this.builder.setClientSendOptsMapper(optsMapper);
return this;
}

public bidirectional(set: boolean = true): FetchEndpoint {
this.protocolMode = set
? ProtocolMode.BIDI_STREAM
Expand Down
8 changes: 6 additions & 2 deletions packages/restate-sdk/src/endpoint/handlers/generic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -304,10 +304,12 @@ export class GenericHandler implements RestateHandler {
attemptCompletedSignal: abortSignal,
};

const handlerComponent = handler.component();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In hindsight, this extraction is not necessary, can inline all 3 usages if preferred


// Prepare logger
const loggerContext = new LoggerContext(
input.invocation_id,
handler.component().name(),
handlerComponent.name(),
handler.name(),
handler.kind() === HandlerKind.SERVICE ? undefined : input.key,
invocationRequest,
Expand Down Expand Up @@ -355,7 +357,9 @@ export class GenericHandler implements RestateHandler {
invocationRequest,
invocationEndPromise,
inputReader,
outputWriter
outputWriter,
handlerComponent.clientCallOptsMapper,
handlerComponent.clientSendOptsMapper
);

// Finally invoke user handler
Expand Down
18 changes: 18 additions & 0 deletions packages/restate-sdk/src/endpoint/lambda_endpoint.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ import type {
VirtualObjectDefinition,
WorkflowDefinition,
} from "@restatedev/restate-sdk-core";
import type {
ClientCallOptsMapper,
ClientSendOptsMapper,
} from "../types/rpc.js";
import type { Component } from "../types/components.js";
import { EndpointBuilder } from "./endpoint_builder.js";
import type { RestateEndpointBase } from "../endpoint.js";
Expand Down Expand Up @@ -76,6 +80,20 @@ export class LambdaEndpointImpl implements LambdaEndpoint {
return this;
}

public setClientCallOptsMapper(
optsMapper: ClientCallOptsMapper
): LambdaEndpoint {
this.builder.setClientCallOptsMapper(optsMapper);
return this;
}

public setClientSendOptsMapper(
optsMapper: ClientSendOptsMapper
): LambdaEndpoint {
this.builder.setClientSendOptsMapper(optsMapper);
return this;
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
handler(): (event: any, ctx: any) => Promise<any> {
const genericHandler = new GenericHandler(
Expand Down
18 changes: 18 additions & 0 deletions packages/restate-sdk/src/endpoint/node_endpoint.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ import type {
import type { Http2ServerRequest, Http2ServerResponse } from "http2";
import * as http2 from "http2";
import { LambdaHandler } from "./handlers/lambda.js";
import type {
ClientCallOptsMapper,
ClientSendOptsMapper,
} from "../types/rpc.js";
import type { Component } from "../types/components.js";
import { EndpointBuilder } from "./endpoint_builder.js";
import { GenericHandler } from "./handlers/generic.js";
Expand Down Expand Up @@ -66,6 +70,20 @@ export class NodeEndpoint implements RestateEndpoint {
return this;
}

public setClientCallOptsMapper(
optsMapper: ClientCallOptsMapper
): RestateEndpoint {
this.builder.setClientCallOptsMapper(optsMapper);
return this;
}

public setClientSendOptsMapper(
optsMapper: ClientSendOptsMapper
): RestateEndpoint {
this.builder.setClientSendOptsMapper(optsMapper);
return this;
}

http2Handler(): (
request: Http2ServerRequest,
response: Http2ServerResponse
Expand Down
28 changes: 20 additions & 8 deletions packages/restate-sdk/src/types/components.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,19 @@

import * as d from "./discovery.js";
import type { ContextImpl } from "../context_impl.js";
import type { HandlerWrapper } from "./rpc.js";
import type {
HandlerWrapper,
ClientCallOptsMapper,
ClientSendOptsMapper,
} from "./rpc.js";
import { HandlerKind } from "./rpc.js";

//
// Interfaces
//
export interface Component {
clientCallOptsMapper: ClientCallOptsMapper;
clientSendOptsMapper: ClientSendOptsMapper;
name(): string;
handlerMatching(url: InvokePathComponents): ComponentHandler | undefined;
discovery(): d.Service;
Expand All @@ -29,8 +35,8 @@ export interface Component {
export interface ComponentHandler {
name(): string;
component(): Component;
invoke(context: ContextImpl, input: Uint8Array): Promise<Uint8Array>;
kind(): HandlerKind;
invoke(context: ContextImpl, input: Uint8Array): Promise<Uint8Array>;
}

//
Expand Down Expand Up @@ -92,8 +98,10 @@ export class ServiceComponent implements Component {

constructor(
private readonly componentName: string,
public readonly description?: string,
public readonly metadata?: Record<string, string>
public readonly description: string | undefined,
public readonly metadata: Record<string, string> | undefined,
public readonly clientCallOptsMapper: ClientCallOptsMapper,
public readonly clientSendOptsMapper: ClientSendOptsMapper
) {}

name(): string {
Expand Down Expand Up @@ -164,8 +172,10 @@ export class VirtualObjectComponent implements Component {

constructor(
public readonly componentName: string,
public readonly description?: string,
public readonly metadata?: Record<string, string>
public readonly description: string | undefined,
public readonly metadata: Record<string, string> | undefined,
public readonly clientCallOptsMapper: ClientCallOptsMapper,
public readonly clientSendOptsMapper: ClientSendOptsMapper
) {}

name(): string {
Expand Down Expand Up @@ -239,8 +249,10 @@ export class WorkflowComponent implements Component {

constructor(
public readonly componentName: string,
public readonly description?: string,
public readonly metadata?: Record<string, string>
public readonly description: string | undefined,
public readonly metadata: Record<string, string> | undefined,
public readonly clientCallOptsMapper: ClientCallOptsMapper,
public readonly clientSendOptsMapper: ClientSendOptsMapper
) {}

name(): string {
Expand Down
Loading