Skip to content

Commit

Permalink
fix(core): Fix issues caused by f235249
Browse files Browse the repository at this point in the history
The fix f235249 inadvertently broke transactions across field
resolvers in all databases apart from SQLite. This commit solves that.
  • Loading branch information
michaelbromley committed Sep 25, 2024
1 parent bdd2595 commit 5a4299a
Show file tree
Hide file tree
Showing 9 changed files with 116 additions and 61 deletions.
4 changes: 2 additions & 2 deletions packages/asset-server-plugin/src/common.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { REQUEST_CONTEXT_KEY } from '@vendure/core/dist/common/constants';
import { internal_getRequestContext } from '@vendure/core';
import { Request } from 'express';

import { AssetServerOptions, ImageTransformFormat } from './types';
Expand All @@ -18,7 +18,7 @@ export function getAssetUrlPrefixFn(options: AssetServerOptions) {
}
if (typeof assetUrlPrefix === 'function') {
return (request: Request, identifier: string) => {
const ctx = (request as any)[REQUEST_CONTEXT_KEY];
const ctx = internal_getRequestContext(request);
return assetUrlPrefix(ctx, identifier);
};
}
Expand Down
13 changes: 13 additions & 0 deletions packages/core/src/api/common/is-field-resolver.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import { GraphQLResolveInfo } from 'graphql';

/**
* Returns true is this guard is being called on a FieldResolver, i.e. not a top-level
* Query or Mutation resolver.
*/
export function isFieldResolver(info?: GraphQLResolveInfo): boolean {
if (!info) {
return false;
}
const parentType = info?.parentType?.name;
return parentType !== 'Query' && parentType !== 'Mutation' && parentType !== 'Subscription';
}
82 changes: 70 additions & 12 deletions packages/core/src/api/common/request-context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,13 @@ import { ID, JsonCompatible } from '@vendure/common/lib/shared-types';
import { isObject } from '@vendure/common/lib/shared-utils';
import { Request } from 'express';
import { TFunction } from 'i18next';
import { EntityManager } from 'typeorm';

import { REQUEST_CONTEXT_KEY, REQUEST_CONTEXT_MAP_KEY } from '../../common/constants';
import {
REQUEST_CONTEXT_KEY,
REQUEST_CONTEXT_MAP_KEY,
TRANSACTION_MANAGER_KEY,
} from '../../common/constants';
import { idsAreEqual } from '../../common/utils';
import { CachedSession } from '../../config/session-cache/session-cache-strategy';
import { Channel } from '../../entity/channel/channel.entity';
Expand All @@ -22,6 +27,32 @@ export type SerializedRequestContext = {
_authorizedAsOwnerOnly: boolean;
};

/**
* This object is used to store the RequestContext on the Express Request object.
*/
interface RequestContextStore {
/**
* This is the default RequestContext for the handler.
*/
default: RequestContext;
/**
* If a transaction is started, the resulting RequestContext is stored here.
* This RequestContext will have a transaction manager attached via the
* TRANSACTION_MANAGER_KEY symbol.
*
* When a transaction is started, the TRANSACTION_MANAGER_KEY symbol is added to the RequestContext
* object. This is then detected inside the {@link internal_setRequestContext} function and the
* RequestContext object is stored in the RequestContextStore under the withTransactionManager key.
*/
withTransactionManager?: RequestContext;
}

interface RequestWithStores extends Request {
// eslint-disable-next-line @typescript-eslint/ban-types
[REQUEST_CONTEXT_MAP_KEY]?: Map<Function, RequestContextStore>;
[REQUEST_CONTEXT_KEY]?: RequestContextStore;
}

/**
* @description
* This function is used to set the {@link RequestContext} on the `req` object. This is the underlying
Expand All @@ -42,23 +73,39 @@ export type SerializedRequestContext = {
* We named it this way to discourage usage outside the framework internals.
*/
export function internal_setRequestContext(
req: Request,
req: RequestWithStores,
ctx: RequestContext,
executionContext?: ExecutionContext,
) {
// If we have access to the `ExecutionContext`, it means we are able to bind
// the `ctx` object to the specific "handler", i.e. the resolver function (for GraphQL)
// or controller (for REST).
let item: RequestContextStore | undefined;
if (executionContext && typeof executionContext.getHandler === 'function') {
// eslint-disable-next-line @typescript-eslint/ban-types
const map: Map<Function, RequestContext> = (req as any)[REQUEST_CONTEXT_MAP_KEY] || new Map();
map.set(executionContext.getHandler(), ctx);
const map = req[REQUEST_CONTEXT_MAP_KEY] || new Map();
item = map.get(executionContext.getHandler());
const ctxHasTransaction = Object.getOwnPropertySymbols(ctx).includes(TRANSACTION_MANAGER_KEY);
if (item) {
item.default = item.default ?? ctx;
if (ctxHasTransaction) {
item.withTransactionManager = ctx;
}
} else {
item = {
default: ctx,
withTransactionManager: ctxHasTransaction ? ctx : undefined,
};
}
map.set(executionContext.getHandler(), item);

(req as any)[REQUEST_CONTEXT_MAP_KEY] = map;
req[REQUEST_CONTEXT_MAP_KEY] = map;
}
// We also bind to a shared key so that we can access the `ctx` object
// later even if we don't have a reference to the `ExecutionContext`
(req as any)[REQUEST_CONTEXT_KEY] = ctx;
req[REQUEST_CONTEXT_KEY] = item ?? {
default: ctx,
};
}

/**
Expand All @@ -67,20 +114,31 @@ export function internal_setRequestContext(
* for more details on this mechanism.
*/
export function internal_getRequestContext(
req: Request,
req: RequestWithStores,
executionContext?: ExecutionContext,
): RequestContext {
let item: RequestContextStore | undefined;
if (executionContext && typeof executionContext.getHandler === 'function') {
// eslint-disable-next-line @typescript-eslint/ban-types
const map: Map<Function, RequestContext> | undefined = (req as any)[REQUEST_CONTEXT_MAP_KEY];
const ctx = map?.get(executionContext.getHandler());
const map = req[REQUEST_CONTEXT_MAP_KEY];
item = map?.get(executionContext.getHandler());
// If we have a ctx associated with the current handler (resolver function), we
// return it. Otherwise, we fall back to the shared key which will be there.
if (ctx) {
return ctx;
if (item) {
return item.withTransactionManager || item.default;
}
}
return (req as any)[REQUEST_CONTEXT_KEY];
if (!item) {
item = req[REQUEST_CONTEXT_KEY] as RequestContextStore;
}
const transactionalCtx =
item?.withTransactionManager &&
((item.withTransactionManager as any)[TRANSACTION_MANAGER_KEY] as EntityManager | undefined)
?.queryRunner?.isReleased === false
? item.withTransactionManager
: undefined;

return transactionalCtx || item.default;
}

/**
Expand Down
9 changes: 4 additions & 5 deletions packages/core/src/api/config/generate-resolvers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ import { StockMovementType } from '@vendure/common/lib/generated-types';
import { GraphQLSchema } from 'graphql';
import { GraphQLDateTime, GraphQLJSON } from 'graphql-scalars';

import { REQUEST_CONTEXT_KEY } from '../../common/constants';
import { InternalServerError } from '../../common/error/errors';
import {
adminErrorOperationTypeResolvers,
Expand All @@ -18,7 +17,7 @@ import { Region } from '../../entity/region/region.entity';
import { getPluginAPIExtensions } from '../../plugin/plugin-metadata';
import { CustomFieldRelationResolverService } from '../common/custom-field-relation-resolver.service';
import { ApiType } from '../common/get-api-type';
import { RequestContext } from '../common/request-context';
import { internal_getRequestContext } from '../common/request-context';
import { userHasPermissionsOnCustomField } from '../common/user-has-permissions-on-custom-field';

import { getCustomFieldsConfigWithoutInterfaces } from './get-custom-fields-config-without-interfaces';
Expand Down Expand Up @@ -206,7 +205,7 @@ function generateCustomFieldRelationResolvers(
let resolver: IFieldResolver<any, any>;
if (isRelationalType(fieldDef)) {
resolver = async (source: any, args: any, context: any) => {
const ctx: RequestContext = context.req[REQUEST_CONTEXT_KEY];
const ctx = internal_getRequestContext(context.req);
if (!userHasPermissionsOnCustomField(ctx, fieldDef)) {
return null;
}
Expand Down Expand Up @@ -235,7 +234,7 @@ function generateCustomFieldRelationResolvers(
};
} else {
resolver = async (source: any, args: any, context: any) => {
const ctx: RequestContext = context.req[REQUEST_CONTEXT_KEY];
const ctx = internal_getRequestContext(context.req);
if (!userHasPermissionsOnCustomField(ctx, fieldDef)) {
return null;
}
Expand Down Expand Up @@ -271,7 +270,7 @@ function generateCustomFieldRelationResolvers(

function getCustomScalars(configService: ConfigService, apiType: 'admin' | 'shop') {
return getPluginAPIExtensions(configService.plugins, apiType)
.map(e => (typeof e.scalars === 'function' ? e.scalars() : e.scalars ?? {}))
.map(e => (typeof e.scalars === 'function' ? e.scalars() : (e.scalars ?? {})))
.reduce(
(all, scalarMap) => ({
...all,
Expand Down
16 changes: 7 additions & 9 deletions packages/core/src/api/decorators/request-context.decorator.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import { ContextType, createParamDecorator, ExecutionContext } from '@nestjs/common';
import { createParamDecorator, ExecutionContext } from '@nestjs/common';

import { isFieldResolver } from '../common/is-field-resolver';
import { parseContext } from '../common/parse-context';
import { internal_getRequestContext } from '../common/request-context';

/**
Expand All @@ -18,12 +20,8 @@ import { internal_getRequestContext } from '../common/request-context';
* @docsCategory request
* @docsPage Ctx Decorator
*/
export const Ctx = createParamDecorator((data, ctx: ExecutionContext) => {
if (ctx.getType<ContextType | 'graphql'>() === 'graphql') {
// GraphQL request
return internal_getRequestContext(ctx.getArgByIndex(2).req, ctx);
} else {
// REST request
return internal_getRequestContext(ctx.switchToHttp().getRequest(), ctx);
}
export const Ctx = createParamDecorator((data, executionContext: ExecutionContext) => {
const context = parseContext(executionContext);
const handlerIsFieldResolver = context.isGraphQL && isFieldResolver(context.info);
return internal_getRequestContext(context.req, handlerIsFieldResolver ? undefined : executionContext);
});
20 changes: 4 additions & 16 deletions packages/core/src/api/middleware/auth-guard.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ import { CanActivate, ExecutionContext, Injectable } from '@nestjs/common';
import { Reflector } from '@nestjs/core';
import { Permission } from '@vendure/common/lib/generated-types';
import { Request, Response } from 'express';
import { GraphQLResolveInfo } from 'graphql';

import { ForbiddenError } from '../../common/error/errors';
import { ConfigService } from '../../config/config.service';
Expand All @@ -14,6 +13,7 @@ import { ChannelService } from '../../service/services/channel.service';
import { CustomerService } from '../../service/services/customer.service';
import { SessionService } from '../../service/services/session.service';
import { extractSessionToken } from '../common/extract-session-token';
import { isFieldResolver } from '../common/is-field-resolver';
import { parseContext } from '../common/parse-context';
import {
internal_getRequestContext,
Expand Down Expand Up @@ -47,16 +47,16 @@ export class AuthGuard implements CanActivate {

async canActivate(context: ExecutionContext): Promise<boolean> {
const { req, res, info } = parseContext(context);
const isFieldResolver = this.isFieldResolver(info);
const targetIsFieldResolver = isFieldResolver(info);
const permissions = this.reflector.get<Permission[]>(PERMISSIONS_METADATA_KEY, context.getHandler());
if (isFieldResolver && !permissions) {
if (targetIsFieldResolver && !permissions) {
return true;
}
const authDisabled = this.configService.authOptions.disableAuth;
const isPublic = !!permissions && permissions.includes(Permission.Public);
const hasOwnerPermission = !!permissions && permissions.includes(Permission.Owner);
let requestContext: RequestContext;
if (isFieldResolver) {
if (targetIsFieldResolver) {
requestContext = internal_getRequestContext(req);
} else {
const session = await this.getSession(req, res, hasOwnerPermission);
Expand Down Expand Up @@ -168,16 +168,4 @@ export class AuthGuard implements CanActivate {
}
return serializedSession;
}

/**
* Returns true is this guard is being called on a FieldResolver, i.e. not a top-level
* Query or Mutation resolver.
*/
private isFieldResolver(info?: GraphQLResolveInfo): boolean {
if (!info) {
return false;
}
const parentType = info?.parentType?.name;
return parentType !== 'Query' && parentType !== 'Mutation' && parentType !== 'Subscription';
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import { CallHandler, ExecutionContext, Injectable, NestInterceptor } from '@nestjs/common';
import { ModuleRef } from '@nestjs/core';
import { GqlExecutionContext } from '@nestjs/graphql';
import { LanguageCode } from '@vendure/common/lib/generated-types';
import { getGraphQlInputName } from '@vendure/common/lib/shared-utils';
import {
GraphQLInputType,
Expand All @@ -12,12 +11,11 @@ import {
TypeNode,
} from 'graphql';

import { REQUEST_CONTEXT_KEY } from '../../common/constants';
import { Injector } from '../../common/injector';
import { ConfigService } from '../../config/config.service';
import { CustomFieldConfig, CustomFields } from '../../config/custom-field/custom-field-types';
import { parseContext } from '../common/parse-context';
import { RequestContext } from '../common/request-context';
import { internal_getRequestContext, RequestContext } from '../common/request-context';
import { validateCustomFieldValue } from '../common/validate-custom-field-value';

/**
Expand All @@ -29,7 +27,10 @@ import { validateCustomFieldValue } from '../common/validate-custom-field-value'
export class ValidateCustomFieldsInterceptor implements NestInterceptor {
private readonly inputsWithCustomFields: Set<string>;

constructor(private configService: ConfigService, private moduleRef: ModuleRef) {
constructor(
private configService: ConfigService,
private moduleRef: ModuleRef,
) {
this.inputsWithCustomFields = Object.keys(configService.customFields).reduce((inputs, entityName) => {
inputs.add(`Create${entityName}Input`);
inputs.add(`Update${entityName}Input`);
Expand All @@ -45,7 +46,7 @@ export class ValidateCustomFieldsInterceptor implements NestInterceptor {
const gqlExecutionContext = GqlExecutionContext.create(context);
const { operation, schema } = parsedContext.info;
const variables = gqlExecutionContext.getArgs();
const ctx: RequestContext = (parsedContext.req as any)[REQUEST_CONTEXT_KEY];
const ctx = internal_getRequestContext(parsedContext.req);

if (operation.operation === 'mutation') {
const inputTypeNames = this.getArgumentMap(operation, schema);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,7 @@ export class ProductVariantAdminEntityResolver {
}

@ResolveField()
async stockOnHand(
@Ctx() ctx: RequestContext,
@Parent() productVariant: ProductVariant,
@Args() args: { options: StockMovementListOptions },
): Promise<number> {
async stockOnHand(@Ctx() ctx: RequestContext, @Parent() productVariant: ProductVariant): Promise<number> {
const { stockOnHand } = await this.stockLevelService.getAvailableStock(ctx, productVariant.id);
return stockOnHand;
}
Expand Down
16 changes: 9 additions & 7 deletions packages/core/src/connection/transaction-wrapper.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { from, lastValueFrom, Observable, of } from 'rxjs';
import { from, lastValueFrom, Observable } from 'rxjs';
import { retryWhen, take, tap } from 'rxjs/operators';
import { Connection, EntityManager, QueryRunner } from 'typeorm';
import { DataSource, EntityManager, QueryRunner } from 'typeorm';
import { TransactionAlreadyStartedError } from 'typeorm/error/TransactionAlreadyStartedError';

import { RequestContext } from '../api/common/request-context';
Expand Down Expand Up @@ -28,13 +28,13 @@ export class TransactionWrapper {
work: (ctx: RequestContext) => Observable<T> | Promise<T>,
mode: TransactionMode,
isolationLevel: TransactionIsolationLevel | undefined,
connection: Connection,
connection: DataSource,
): Promise<T> {
// Copy to make sure original context will remain valid after transaction completes
const ctx = originalCtx.copy();

const entityManager: EntityManager | undefined = (ctx as any)[TRANSACTION_MANAGER_KEY];
const queryRunner = entityManager ?.queryRunner || connection.createQueryRunner();
const queryRunner = entityManager?.queryRunner || connection.createQueryRunner();

if (mode === 'auto') {
await this.startTransaction(queryRunner, isolationLevel);
Expand Down Expand Up @@ -67,8 +67,7 @@ export class TransactionWrapper {
}
throw error;
} finally {
if (!queryRunner.isTransactionActive
&& queryRunner.isReleased === false) {
if (!queryRunner.isTransactionActive && queryRunner.isReleased === false) {
// There is a check for an active transaction
// because this could be a nested transaction (savepoint).

Expand All @@ -81,7 +80,10 @@ export class TransactionWrapper {
* Attempts to start a DB transaction, with retry logic in the case that a transaction
* is already started for the connection (which is mainly a problem with SQLite/Sql.js)
*/
private async startTransaction(queryRunner: QueryRunner, isolationLevel: TransactionIsolationLevel | undefined) {
private async startTransaction(
queryRunner: QueryRunner,
isolationLevel: TransactionIsolationLevel | undefined,
) {
const maxRetries = 25;
let attempts = 0;
let lastError: any;
Expand Down

0 comments on commit 5a4299a

Please sign in to comment.