Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -223,18 +223,36 @@ Configure the `@orpc/nest` module by importing `ORPCModule` in your NestJS appli
import { REQUEST } from '@nestjs/core'
import { onError, ORPCModule } from '@orpc/nest'
import { Request } from 'express' // if you use express adapter
import { experimental_SmartCoercionPlugin as SmartCoercionPlugin } from '@orpc/json-schema'
import { ZodToJsonSchemaConverter } from '@orpc/zod/zod4'

declare module '@orpc/nest' {
/**
* Extend oRPC global context to make it type-safe inside your handlers/middlewares
*/
interface ORPCGlobalContext {
request: Request
}
}

@Module({
imports: [
ORPCModule.forRootAsync({ // or .forRoot
ORPCModule.forRootAsync({ // use forRoot for static configuration
useFactory: (request: Request) => ({
context: { request }, // oRPC context, accessible from middlewares, etc.
eventIteratorKeepAliveInterval: 5000, // 5 seconds
interceptors: [
onError((error) => {
console.error(error)
}),
],
context: { request }, // oRPC context, accessible from middlewares, etc.
eventIteratorKeepAliveInterval: 5000, // 5 seconds
plugins: [
new SmartCoercionPlugin({
schemaConverters: [
new ZodToJsonSchemaConverter(),
],
}),
],
}),
inject: [REQUEST],
}),
Expand All @@ -244,10 +262,7 @@ export class AppModule {}
```

::: info

- **`interceptors`** - [Server-side client interceptors](/docs/client/server-side#lifecycle) for intercepting input, output, and errors.
- **`eventIteratorKeepAliveInterval`** - Keep-alive interval for event streams (see [Event Iterator Keep Alive](/docs/rpc-handler#event-iterator-keep-alive))

These configurations are optional and support most options available in [OpenAPIHandler](/docs/openapi/openapi-handler).
:::

## Create a Type-Safe Client
Expand Down
59 changes: 54 additions & 5 deletions packages/nest/src/implement.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { REQUEST } from '@nestjs/core'
import { FastifyAdapter } from '@nestjs/platform-fastify'
import { Test } from '@nestjs/testing'
import { oc, ORPCError } from '@orpc/contract'
import * as StandardOpenAPIClientModule from '@orpc/openapi-client/standard'
import { implement, lazy } from '@orpc/server'
import * as StandardServerNode from '@orpc/standard-server-node'
import supertest from 'supertest'
Expand All @@ -15,6 +16,15 @@ import * as z from 'zod'
import { Implement } from './implement'
import { ORPCModule } from './module'

vi.mock('@orpc/openapi-client/standard', async (importActual) => {
const actual = await importActual<any>()
return {
...actual,
StandardOpenAPIJsonSerializer: vi.fn().mockImplementation((...args: any[]) => new actual.StandardOpenAPIJsonSerializer(...args)),
StandardBracketNotationSerializer: vi.fn().mockImplementation((...args: any[]) => new actual.StandardBracketNotationSerializer(...args)),
}
})

const sendStandardResponseSpy = vi.spyOn(StandardServerNode, 'sendStandardResponse')

beforeEach(() => {
Expand Down Expand Up @@ -387,11 +397,21 @@ describe('@Implement', async () => {

it('works with ORPCModule.forRoot', async () => {
const interceptor = vi.fn(({ next }) => next())
const interceptors = [interceptor]
const customJsonSerializers = [
{
condition: (data: unknown) => data === 'special',
serialize: () => 'SPECIAL_SERIALIZED',
},
]
const moduleRef = await Test.createTestingModule({
imports: [
ORPCModule.forRoot({
interceptors: [interceptor],
interceptors,
context: { customValue: 42 },
eventIteratorKeepAliveComment: '__TEST__',
customJsonSerializers,
maxBracketNotationArrayIndex: 9404,
}),
],
controllers: [ImplProcedureController],
Expand All @@ -410,7 +430,18 @@ describe('@Implement', async () => {
expect(res.statusCode).toEqual(200)
expect(res.body).toEqual('pong')

expect(StandardOpenAPIClientModule.StandardOpenAPIJsonSerializer).toHaveBeenCalledWith(expect.objectContaining({
customJsonSerializers,
}))
// make sure the config object is cloned internally
expect(vi.mocked(StandardOpenAPIClientModule.StandardOpenAPIJsonSerializer).mock.calls[0]![0]?.customJsonSerializers).not.toBe(customJsonSerializers)

expect(StandardOpenAPIClientModule.StandardBracketNotationSerializer).toHaveBeenCalledWith(expect.objectContaining({
maxBracketNotationArrayIndex: 9404,
}))

expect(interceptor).toHaveBeenCalledTimes(1)
expect(interceptor).toHaveBeenCalledWith(expect.objectContaining({ context: { customValue: 42 } }))
expect(sendStandardResponseSpy).toHaveBeenCalledTimes(1)
expect(sendStandardResponseSpy).toHaveBeenCalledWith(expect.anything(), expect.anything(), expect.objectContaining({
eventIteratorKeepAliveComment: '__TEST__',
Expand All @@ -419,15 +450,22 @@ describe('@Implement', async () => {

it('works with ORPCModule.forRootAsync', async () => {
const interceptor = vi.fn(({ next }) => next())
const interceptors = [interceptor]
const customJsonSerializers = [
{
condition: (data: unknown) => data === 'special',
serialize: () => 'SPECIAL_SERIALIZED',
},
]
const moduleRef = await Test.createTestingModule({
imports: [
ORPCModule.forRootAsync({
useFactory: async (request: Request) => ({
interceptors: [interceptor],
interceptors,
eventIteratorKeepAliveComment: '__TEST__',
context: {
request,
},
context: { request, customValue: 42 },
customJsonSerializers,
maxBracketNotationArrayIndex: 23979,
}),
inject: [REQUEST],
}),
Expand All @@ -448,6 +486,16 @@ describe('@Implement', async () => {
expect(res.statusCode).toEqual(200)
expect(res.body).toEqual('pong')

expect(StandardOpenAPIClientModule.StandardOpenAPIJsonSerializer).toHaveBeenCalledWith(expect.objectContaining({
customJsonSerializers,
}))
// make sure the config object is cloned internally
expect(vi.mocked(StandardOpenAPIClientModule.StandardOpenAPIJsonSerializer).mock.calls[0]![0]?.customJsonSerializers).not.toBe(customJsonSerializers)

expect(StandardOpenAPIClientModule.StandardBracketNotationSerializer).toHaveBeenCalledWith(expect.objectContaining({
maxBracketNotationArrayIndex: 23979,
}))

expect(interceptor).toHaveBeenCalledTimes(1)
expect(interceptor).toHaveBeenCalledWith(expect.objectContaining({
context: expect.objectContaining({
Expand All @@ -457,6 +505,7 @@ describe('@Implement', async () => {
'x-custom': 'value',
}),
}),
customValue: 42,
}),
}))
expect(sendStandardResponseSpy).toHaveBeenCalledTimes(1)
Expand Down
77 changes: 32 additions & 45 deletions packages/nest/src/implement.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,17 @@ import type { ContractRouter } from '@orpc/contract'
import type { Router } from '@orpc/server'
import type { StandardParams } from '@orpc/server/standard'
import type { Promisable } from '@orpc/shared'
import type { StandardResponse } from '@orpc/standard-server'
import type { Request, Response } from 'express'
import type { FastifyReply, FastifyRequest } from 'fastify'
import type { Observable } from 'rxjs'
import type { ORPCModuleConfig } from './module'
import { applyDecorators, Delete, Get, Head, Inject, Injectable, Optional, Patch, Post, Put, UseInterceptors } from '@nestjs/common'
import { toORPCError } from '@orpc/client'
import { fallbackContractConfig, isContractProcedure } from '@orpc/contract'
import { StandardBracketNotationSerializer, StandardOpenAPIJsonSerializer, StandardOpenAPISerializer } from '@orpc/openapi-client/standard'
import { StandardOpenAPICodec } from '@orpc/openapi/standard'
import { createProcedureClient, getRouter, isProcedure, ORPCError, unlazy } from '@orpc/server'
import { get } from '@orpc/shared'
import { flattenHeader } from '@orpc/standard-server'
import { getRouter, isProcedure, unlazy } from '@orpc/server'
import { StandardHandler } from '@orpc/server/standard'
import { clone, get } from '@orpc/shared'
import * as StandardServerFastify from '@orpc/standard-server-fastify'
import * as StandardServerNode from '@orpc/standard-server-node'
import { mergeMap } from 'rxjs'
Expand Down Expand Up @@ -90,20 +88,20 @@ export function Implement<T extends ContractRouter<any>>(
}
}

const codec = new StandardOpenAPICodec(
new StandardOpenAPISerializer(
new StandardOpenAPIJsonSerializer(),
new StandardBracketNotationSerializer(),
),
)

type NestParams = Record<string, string | string[]>

@Injectable()
export class ImplementInterceptor implements NestInterceptor {
private readonly config: ORPCModuleConfig

constructor(
@Inject(ORPC_MODULE_CONFIG_SYMBOL) @Optional() private readonly config: ORPCModuleConfig | undefined,
@Inject(ORPC_MODULE_CONFIG_SYMBOL) @Optional() config: ORPCModuleConfig | undefined,
) {
// @Optional() doesn't support default values, so we handle it here.
// We clone the config to prevent conflicts when multiple handlers
// modify the same object through the plugins system.
// TODO: improve plugins system to avoid mutating config directly.
this.config = clone(config) ?? {}
}

intercept(ctx: ExecutionContext, next: CallHandler<any>): Observable<any> {
Expand All @@ -124,40 +122,29 @@ export class ImplementInterceptor implements NestInterceptor {
? StandardServerFastify.toStandardLazyRequest(req, res as FastifyReply)
: StandardServerNode.toStandardLazyRequest(req, res as Response)

const standardResponse: StandardResponse = await (async () => {
let isDecoding = false

try {
const client = createProcedureClient(procedure, this.config)

isDecoding = true
const input = await codec.decode(standardRequest, flattenParams(req.params as NestParams), procedure)
isDecoding = false

const output = await client(input, {
signal: standardRequest.signal,
lastEventId: flattenHeader(standardRequest.headers['last-event-id']),
})

return codec.encode(output, procedure)
const codec = new StandardOpenAPICodec(
new StandardOpenAPISerializer(
new StandardOpenAPIJsonSerializer(this.config),
new StandardBracketNotationSerializer(this.config),
),
)

const handler = new StandardHandler(procedure, {
init: () => {},
match: () => Promise.resolve({ path: [], procedure, params: flattenParams(req.params as NestParams) }),
}, codec, this.config)

const result = await handler.handle(standardRequest, {
context: this.config.context,
})

if (result.matched) {
if ('raw' in res) {
await StandardServerFastify.sendStandardResponse(res, result.response, this.config)
}
catch (e) {
const error = isDecoding && !(e instanceof ORPCError)
? new ORPCError('BAD_REQUEST', {
message: `Malformed request. Ensure the request body is properly formatted and the 'Content-Type' header is set correctly.`,
cause: e,
})
: toORPCError(e)

return codec.encodeError(error)
else {
await StandardServerNode.sendStandardResponse(res, result.response, this.config)
}
})()

if ('raw' in res) {
await StandardServerFastify.sendStandardResponse(res, standardResponse, this.config)
}
else {
await StandardServerNode.sendStandardResponse(res, standardResponse, this.config)
}
}),
)
Expand Down
21 changes: 21 additions & 0 deletions packages/nest/src/index.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import * as ServerModule from '@orpc/server'
import { expect, it, vi } from 'vitest'
import { implement } from './index'

vi.mock('@orpc/server', async (importOriginal) => {
const original = await importOriginal<typeof import('@orpc/server')>()
return {
...original,
implement: vi.fn(original.implement),
}
})

it('implement is aliased', () => {
const contract = { nested: {} }
const options = { dedupeLeadingMiddlewares: false }
const impl = implement(contract, options)

expect(ServerModule.implement).toHaveBeenCalledTimes(1)
expect(ServerModule.implement).toHaveBeenCalledWith(contract, options)
expect(impl).toBe(vi.mocked(ServerModule.implement).mock.results[0]!.value)
})
17 changes: 16 additions & 1 deletion packages/nest/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import type { AnyContractRouter } from '@orpc/contract'
import type { BuilderConfig, Context, Implementer } from '@orpc/server'
import type { ORPCGlobalContext } from './module'
import { implement as baseImplement } from '@orpc/server'

export * from './implement'
export { Implement as Impl } from './implement'
export * from './module'
export * from './utils'

export { implement, onError, onFinish, onStart, onSuccess, ORPCError } from '@orpc/server'
export { onError, onFinish, onStart, onSuccess, ORPCError } from '@orpc/server'
export type {
ImplementedProcedure,
Implementer,
Expand All @@ -13,3 +18,13 @@ export type {
RouterImplementer,
RouterImplementerWithMiddlewares,
} from '@orpc/server'

/**
* Alias for `implement` from `@orpc/server` with default context set to `ORPCGlobalContext`
*/
export function implement<T extends AnyContractRouter, TContext extends Context = ORPCGlobalContext>(
contract: T,
config: BuilderConfig = {},
): Implementer<T, TContext, TContext> {
return baseImplement(contract, config)
}
26 changes: 22 additions & 4 deletions packages/nest/src/module.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,33 @@
import type { DynamicModule } from '@nestjs/common'
import type { AnySchema } from '@orpc/contract'
import type { CreateProcedureClientOptions } from '@orpc/server'
import type { StandardBracketNotationSerializerOptions, StandardOpenAPIJsonSerializerOptions } from '@orpc/openapi-client/standard'
import type { StandardHandlerOptions } from '@orpc/server/standard'
import type { SendStandardResponseOptions } from '@orpc/standard-server-node'
import { Module } from '@nestjs/common'
import { ImplementInterceptor } from './implement'

export const ORPC_MODULE_CONFIG_SYMBOL = Symbol('ORPC_MODULE_CONFIG')

/**
* You can extend this interface to add global context properties.
* @example
* ```ts
* declare module '@orpc/nest' {
* interface ORPCGlobalContext {
* user: { id: string; name: string }
* }
* }
* ```
*/
export interface ORPCGlobalContext {

}

export interface ORPCModuleConfig extends
CreateProcedureClientOptions<object, AnySchema, object, object, object>,
SendStandardResponseOptions {
StandardHandlerOptions<ORPCGlobalContext>,
SendStandardResponseOptions,
StandardOpenAPIJsonSerializerOptions,
StandardBracketNotationSerializerOptions {
context?: ORPCGlobalContext
}

@Module({})
Expand Down
6 changes: 5 additions & 1 deletion packages/shared/src/object.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,17 @@ it('isTypescriptObject', () => {
it('clone', () => {
expect(clone(null)).toBeNull()

const obj = { a: 1, arr: [2, 3], nested: { arr: [{ b: 4 }] } }
const symbol = Symbol('a')
const obj = { a: 1, arr: [2, 3], nested: { arr: [{ b: 4 }], [symbol]: { [symbol]: 5 } } }
const cloned = clone(obj)

expect(cloned).toEqual(obj)
expect(cloned).not.toBe(obj)
expect(cloned.arr).not.toBe(obj.arr)
expect(cloned.nested.arr).not.toBe(obj.nested.arr)
expect(cloned.nested[symbol]).toEqual(obj.nested[symbol])
expect(cloned.nested[symbol]).not.toBe(obj.nested[symbol])
expect(cloned.nested[symbol][symbol]).toBe(5)
})

it('get', () => {
Expand Down
Loading
Loading