Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion biome.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"$schema": "https://biomejs.dev/schemas/2.2.4/schema.json",
"$schema": "https://biomejs.dev/schemas/2.2.5/schema.json",
"vcs": {
"enabled": false,
"clientKind": "git",
Expand Down
6 changes: 1 addition & 5 deletions build.config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,7 @@ import { defineBuildConfig } from "unbuild";
export default defineBuildConfig({
rollup: {
esbuild: {
tsconfigRaw: {
compilerOptions: {
experimentalDecorators: true,
},
},
tsconfigRaw: { compilerOptions: { experimentalDecorators: true } },
},
},
});
151 changes: 39 additions & 112 deletions bun.lock

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@
"dist"
],
"devDependencies": {
"@apollo/server": "^4.12.2",
"@apollo/server": "^5.0.0",
"@as-integrations/express5": "^1.1.2",
"@biomejs/biome": "2.2.4",
"@biomejs/biome": "2.2.5",
"@faker-js/faker": "^10.0.0",
"@nestjs/apollo": "^13.1.0",
"@nestjs/graphql": "^13.1.0",
"@nestjs/apollo": "^13.2.1",
"@nestjs/graphql": "^13.2.0",
"@nestjs/platform-express": "^11.1.6",
"@nestjs/testing": "^11.1.6",
"@swc/cli": "^0.7.8",
Expand Down
122 changes: 61 additions & 61 deletions src/auth-module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {
DiscoveryService,
HttpAdapterHost,
MetadataScanner,
APP_GUARD,
} from "@nestjs/core";
import { toNodeHandler } from "better-auth/node";
import { createAuthMiddleware } from "better-auth/plugins";
Expand All @@ -25,7 +26,6 @@ import { AuthService } from "./auth-service.ts";
import { SkipBodyParsingMiddleware } from "./middlewares.ts";
import { AFTER_HOOK_KEY, BEFORE_HOOK_KEY, HOOK_KEY } from "./symbols.ts";
import { AuthGuard } from "./auth-guard.ts";
import { APP_GUARD } from "@nestjs/core";

const HOOKS = [
{ metadataKey: BEFORE_HOOK_KEY, hookType: "before" as const },
Expand Down Expand Up @@ -70,15 +70,19 @@ export class AuthModule
);

const hasHookProviders = providers.length > 0;
const hooksConfigured =
typeof this.options.auth?.options?.hooks === "object";
const hooks = this.options.auth?.options?.hooks;
// Check if hooks is a valid object (not null, not undefined)
const hooksConfigured = hooks && typeof hooks === "object";

if (hasHookProviders && !hooksConfigured)
// Only throw error if there are hook providers but hooks is not properly configured
if (hasHookProviders && !hooksConfigured) {
throw new Error(
"Detected @Hook providers but Better Auth 'hooks' are not configured. Add 'hooks: {}' to your betterAuth(...) options.",
);
}

if (!hooksConfigured) return;
// Return early if no hook providers - no need to set up hooks
if (!hasHookProviders) return;

for (const provider of providers) {
const providerPrototype = Object.getPrototypeOf(provider.instance);
Expand All @@ -92,31 +96,31 @@ export class AuthModule
}

configure(consumer: MiddlewareConsumer): void {
const trustedOrigins = this.options.auth.options.trustedOrigins;
// function-based trustedOrigins requires a Request (from web-apis) object to evaluate, which is not available in NestJS (we only have a express Request object)
// if we ever need this, take a look at better-call which show an implementation for this
const isNotFunctionBased = trustedOrigins && Array.isArray(trustedOrigins);
const trustedOrigins = this.options.auth?.options?.trustedOrigins;

// Handle CORS configuration based on trustedOrigins
if (trustedOrigins && !this.options.disableTrustedOriginsCors) {
// function-based trustedOrigins requires a Request (from web-apis) object to evaluate,
// which is not available in NestJS (we only have an express Request object)
// if we ever need this, take a look at better-call which shows an implementation for this
if (!Array.isArray(trustedOrigins)) {
throw new Error(
"Function-based trustedOrigins not supported in NestJS. Use string array or disable CORS with disableTrustedOriginsCors: true.",
);
}

if (!this.options.disableTrustedOriginsCors && isNotFunctionBased) {
this.adapter.httpAdapter.enableCors({
origin: trustedOrigins,
methods: ["GET", "POST", "PUT", "DELETE"],
credentials: true,
});
} else if (
trustedOrigins &&
!this.options.disableTrustedOriginsCors &&
!isNotFunctionBased
)
throw new Error(
"Function-based trustedOrigins not supported in NestJS. Use string array or disable CORS with disableTrustedOriginsCors: true.",
);
}

if (!this.options.disableBodyParser)
consumer.apply(SkipBodyParsingMiddleware).forRoutes("*path");

// Get basePath from options or use default
let basePath = this.options.auth.options.basePath ?? "/api/auth";
let basePath = this.options.auth?.options?.basePath ?? "/api/auth";

// Ensure basePath starts with /
if (!basePath.startsWith("/")) {
Expand All @@ -141,47 +145,30 @@ export class AuthModule

private setupHooks(
providerMethod: (...args: unknown[]) => unknown,
providerClass: { new (...args: unknown[]): unknown },
providerInstance: unknown,
) {
if (!this.options.auth.options.hooks) return;
const hooks = this.options.auth.options.hooks;

for (const { metadataKey, hookType } of HOOKS) {
const hasHook = Reflect.hasMetadata(metadataKey, providerMethod);
if (!hasHook) continue;

const hookPath = Reflect.getMetadata(metadataKey, providerMethod);

const originalHook = this.options.auth.options.hooks[hookType];
this.options.auth.options.hooks[hookType] = createAuthMiddleware(
async (ctx) => {
if (originalHook) {
await originalHook(ctx);
}
const originalHook = hooks[hookType];
hooks[hookType] = createAuthMiddleware(async (ctx) => {
if (originalHook) await originalHook(ctx);

if (hookPath && hookPath !== ctx.path) return;
if (hookPath && hookPath !== ctx.path) return;

await providerMethod.apply(providerClass, [ctx]);
},
);
await providerMethod.apply(providerInstance, [ctx]);
});
}
}

static forRootAsync(options: typeof ASYNC_OPTIONS_TYPE): DynamicModule {
const forRootAsyncResult = super.forRootAsync(options);
return {
...super.forRootAsync(options),
providers: [
...(forRootAsyncResult.providers ?? []),
...(!options.disableGlobalAuthGuard
? [
{
provide: APP_GUARD,
useClass: AuthGuard,
},
]
: []),
],
};
const module = super.forRootAsync(options);
return this.addGuardProvider(module, options.disableGlobalAuthGuard);
}

static forRoot(options: typeof OPTIONS_TYPE): DynamicModule;
Expand All @@ -196,25 +183,38 @@ export class AuthModule
arg1: Auth | typeof OPTIONS_TYPE,
arg2?: Omit<typeof OPTIONS_TYPE, "auth">,
): DynamicModule {
const normalizedOptions: typeof OPTIONS_TYPE =
typeof arg1 === "object" && arg1 !== null && "auth" in (arg1 as object)
? (arg1 as typeof OPTIONS_TYPE)
: ({ ...(arg2 ?? {}), auth: arg1 as Auth } as typeof OPTIONS_TYPE);
// Check if using new format: forRoot({ auth, ...options })
const isNewFormat =
typeof arg1 === "object" && arg1 !== null && "auth" in arg1;

const normalizedOptions: typeof OPTIONS_TYPE = isNewFormat
? (arg1 as typeof OPTIONS_TYPE)
: { ...(arg2 ?? {}), auth: arg1 };

const module = super.forRoot(normalizedOptions);
return this.addGuardProvider(
module,
normalizedOptions.disableGlobalAuthGuard,
);
}

const forRootResult = super.forRoot(normalizedOptions);
/**
* Adds the global AuthGuard provider if not disabled
*/
private static addGuardProvider(
module: DynamicModule,
disableGuard?: boolean,
): DynamicModule {
if (disableGuard) return module;

return {
...forRootResult,
...module,
providers: [
...(forRootResult.providers ?? []),
...(!normalizedOptions.disableGlobalAuthGuard
? [
{
provide: APP_GUARD,
useClass: AuthGuard,
},
]
: []),
...(module.providers ?? []),
{
provide: APP_GUARD,
useClass: AuthGuard,
},
],
};
}
Expand Down
2 changes: 1 addition & 1 deletion src/middlewares.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { Injectable, type NestMiddleware } from "@nestjs/common";
import type { NextFunction, Request, Response } from "express";
import * as express from "express";
import express from "express";

@Injectable()
export class SkipBodyParsingMiddleware implements NestMiddleware {
Expand Down
8 changes: 4 additions & 4 deletions src/symbols.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export const BEFORE_HOOK_KEY = Symbol("BEFORE_HOOK") as symbol;
export const AFTER_HOOK_KEY = Symbol("AFTER_HOOK") as symbol;
export const HOOK_KEY = Symbol("HOOK") as symbol;
export const AUTH_MODULE_OPTIONS_KEY = Symbol("AUTH_MODULE_OPTIONS") as symbol;
export const BEFORE_HOOK_KEY = Symbol("BEFORE_HOOK");
export const AFTER_HOOK_KEY = Symbol("AFTER_HOOK");
export const HOOK_KEY = Symbol("HOOK");
export const AUTH_MODULE_OPTIONS_KEY = Symbol("AUTH_MODULE_OPTIONS");
16 changes: 12 additions & 4 deletions src/utils.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
import type { ExecutionContext } from "@nestjs/common";
import { GqlExecutionContext, type GqlContextType } from "@nestjs/graphql";
import type { Request } from "express";

// Session is typed as unknown because users can extend the session type
type RequestWithSession = Request & { session: unknown };

/**
* Extracts the request object from either HTTP or GraphQL execution context
* @param context - The execution context
* @returns The request object
*/
export function getRequestFromContext(context: ExecutionContext) {
export function getRequestFromContext(
context: ExecutionContext,
): RequestWithSession {
const contextType = context.getType<GqlContextType>();
if (contextType === "graphql") {
return GqlExecutionContext.create(context).getContext().req;
const { req } = GqlExecutionContext.create(context).getContext<{
req: RequestWithSession;
}>();
return req;
}

return context.switchToHttp().getRequest();
return context.switchToHttp().getRequest<RequestWithSession>();
}
Loading
Loading