Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 33 additions & 2 deletions packages/runtime/src/enhancements/omit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,39 @@ class OmitHandler extends DefaultPrismaProxyHandler {
}

// base override
protected async processResultEntity<T>(data: T): Promise<T> {
if (data) {
protected async processResultEntity<T>(method: string, data: T): Promise<T> {
if (!data || typeof data !== 'object') {
return data;
}

if (method === 'subscribe' || method === 'stream') {
if (!('action' in data)) {
return data;
}

// Prisma Pulse result
switch (data.action) {
case 'create':
if ('created' in data) {
await this.doPostProcess(data.created, this.model);
}
break;
case 'update':
if ('before' in data) {
await this.doPostProcess(data.before, this.model);
}
if ('after' in data) {
await this.doPostProcess(data.after, this.model);
}
break;
case 'delete':
if ('deleted' in data) {
await this.doPostProcess(data.deleted, this.model);
}
break;
}
} else {
// regular prisma client result
for (const value of enumerate(data)) {
await this.doPostProcess(value, this.model);
}
Expand Down
125 changes: 87 additions & 38 deletions packages/runtime/src/enhancements/policy/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1537,53 +1537,102 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr

//#endregion

//#region Subscribe (Prisma Pulse)
//#region Prisma Pulse

subscribe(args: any) {
return createDeferredPromise(() => {
const readGuard = this.policyUtils.getAuthGuard(this.prisma, this.model, 'read');
if (this.policyUtils.isTrue(readGuard)) {
// no need to inject
if (this.shouldLogQuery) {
this.logger.info(`[policy] \`subscribe\` ${this.model}:\n${formatObject(args)}`);
}
return this.modelClient.subscribe(args);
}

if (!args) {
// include all
args = { create: {}, update: {}, delete: {} };
} else {
if (typeof args !== 'object') {
throw prismaClientValidationError(this.prisma, this.prismaModule, 'argument must be an object');
}
if (Object.keys(args).length === 0) {
// include all
args = { create: {}, update: {}, delete: {} };
} else {
args = this.policyUtils.safeClone(args);
}
}
return this.handleSubscribeStream('subscribe', args);
}

// inject into subscribe conditions
stream(args: any) {
return this.handleSubscribeStream('stream', args);
}

if (args.create) {
args.create.after = this.policyUtils.and(args.create.after, readGuard);
private async handleSubscribeStream(action: 'subscribe' | 'stream', args: any) {
if (!args) {
// include all
args = { create: {}, update: {}, delete: {} };
} else {
if (typeof args !== 'object') {
throw prismaClientValidationError(this.prisma, this.prismaModule, 'argument must be an object');
}
args = this.policyUtils.safeClone(args);
}

if (args.update) {
args.update.after = this.policyUtils.and(args.update.after, readGuard);
// inject read guard as subscription filter
for (const key of ['create', 'update', 'delete']) {
if (args[key] === undefined) {
continue;
}

if (args.delete) {
args.delete.before = this.policyUtils.and(args.delete.before, readGuard);
// "update" has an extra layer of "after"
const payload = key === 'update' ? args[key].after : args[key];
const toInject = { where: payload };
this.policyUtils.injectForRead(this.prisma, this.model, toInject);
if (key === 'update') {
// "update" has an extra layer of "after"
args[key].after = toInject.where;
} else {
args[key] = toInject.where;
}
}

if (this.shouldLogQuery) {
this.logger.info(`[policy] \`subscribe\` ${this.model}:\n${formatObject(args)}`);
}
return this.modelClient.subscribe(args);
});
if (this.shouldLogQuery) {
this.logger.info(`[policy] \`${action}\` ${this.model}:\n${formatObject(args)}`);
}

// Prisma Pulse returns an async iterable, which we need to wrap
// and post-process the iteration results
const iterable = await this.modelClient[action](args);
return {
[Symbol.asyncIterator]: () => {
const iter = iterable[Symbol.asyncIterator].bind(iterable)();
return {
next: async () => {
const { done, value } = await iter.next();
let processedValue = value;
if (value && 'action' in value) {
switch (value.action) {
case 'create':
if ('created' in value) {
processedValue = {
...value,
created: this.policyUtils.postProcessForRead(value.created, this.model, {}),
};
}
break;

case 'update':
if ('before' in value) {
processedValue = {
...value,
before: this.policyUtils.postProcessForRead(value.before, this.model, {}),
};
}
if ('after' in value) {
processedValue = {
...value,
after: this.policyUtils.postProcessForRead(value.after, this.model, {}),
};
}
break;

case 'delete':
if ('deleted' in value) {
processedValue = {
...value,
deleted: this.policyUtils.postProcessForRead(value.deleted, this.model, {}),
};
}
break;
}
}

return { done, value: processedValue };
},
return: () => iter.return?.(),
throw: () => iter.throw?.(),
};
},
};
}

//#endregion
Expand Down
1 change: 1 addition & 0 deletions packages/runtime/src/enhancements/policy/policy-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,7 @@ export class PolicyUtil extends QueryUtils {
// make select and include visible to the injection
const injected: any = { select: args.select, include: args.include };
if (!this.injectAuthGuardAsWhere(db, injected, model, 'read')) {
args.where = this.makeFalse();
return false;
}

Expand Down
36 changes: 31 additions & 5 deletions packages/runtime/src/enhancements/proxy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ export interface PrismaProxyHandler {
count(args: any): Promise<unknown | number>;

subscribe(args: any): Promise<unknown>;

stream(args: any): Promise<unknown>;
}

/**
Expand All @@ -79,7 +81,7 @@ export class DefaultPrismaProxyHandler implements PrismaProxyHandler {
async () => {
args = await this.preprocessArgs(method, args);
const r = await this.prisma[this.model][method](args);
return postProcess ? this.processResultEntity(r) : r;
return postProcess ? this.processResultEntity(method, r) : r;
},
args,
this.options.modelMeta,
Expand All @@ -92,7 +94,7 @@ export class DefaultPrismaProxyHandler implements PrismaProxyHandler {
return createDeferredPromise<TResult>(async () => {
args = await this.preprocessArgs(method, args);
const r = await this.prisma[this.model][method](args);
return postProcess ? this.processResultEntity(r) : r;
return postProcess ? this.processResultEntity(method, r) : r;
});
}

Expand Down Expand Up @@ -161,20 +163,44 @@ export class DefaultPrismaProxyHandler implements PrismaProxyHandler {
}

subscribe(args: any) {
return this.deferred('subscribe', args, false);
return this.doSubscribeStream('subscribe', args);
}

stream(args: any) {
return this.doSubscribeStream('stream', args);
}

private async doSubscribeStream(method: 'subscribe' | 'stream', args: any) {
// Prisma's `subscribe` and `stream` methods return an async iterable
// which we need to wrap to process the iteration results
const iterable = await this.prisma[this.model][method](args);
return {
[Symbol.asyncIterator]: () => {
const iter = iterable[Symbol.asyncIterator].bind(iterable)();
return {
next: async () => {
const { done, value } = await iter.next();
const processedValue = value ? await this.processResultEntity(method, value) : value;
return { done, value: processedValue };
},
return: () => iter.return?.(),
throw: () => iter.throw?.(),
};
},
};
}

/**
* Processes result entities before they're returned
*/
protected async processResultEntity<T>(data: T): Promise<T> {
protected async processResultEntity<T>(_method: PrismaProxyActions, data: T): Promise<T> {
return data;
}

/**
* Processes query args before they're passed to Prisma.
*/
protected async preprocessArgs(method: PrismaProxyActions, args: any) {
protected async preprocessArgs(_method: PrismaProxyActions, args: any) {
return args;
}
}
Expand Down
1 change: 1 addition & 0 deletions packages/runtime/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ export interface DbOperations {
groupBy(args: unknown): Promise<any>;
count(args?: unknown): Promise<any>;
subscribe(args?: unknown): Promise<any>;
stream(args?: unknown): Promise<any>;
check(args: unknown): Promise<boolean>;
fields: Record<string, any>;
}
Expand Down
11 changes: 8 additions & 3 deletions packages/testtools/src/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ export async function loadSchema(schema: string, options?: SchemaLoadOptions) {
}

if (opt.pushDb) {
run('npx prisma db push --skip-generate');
run('npx prisma db push --skip-generate --accept-data-loss');
}

if (opt.pulseApiKey) {
Expand All @@ -264,10 +264,10 @@ export async function loadSchema(schema: string, options?: SchemaLoadOptions) {
// https://github.com/prisma/prisma/issues/18292
prisma[Symbol.for('nodejs.util.inspect.custom')] = 'PrismaClient';

const prismaModule = require(path.join(projectDir, 'node_modules/@prisma/client')).Prisma;
const prismaModule = loadModule('@prisma/client', projectDir).Prisma;

if (opt.pulseApiKey) {
const withPulse = require(path.join(projectDir, 'node_modules/@prisma/extension-pulse/dist/cjs')).withPulse;
const withPulse = loadModule('@prisma/extension-pulse/node', projectDir).withPulse;
prisma = prisma.$extends(withPulse({ apiKey: opt.pulseApiKey }));
}

Expand Down Expand Up @@ -388,3 +388,8 @@ export async function loadZModelAndDmmf(
const dmmf = await getDMMF({ datamodel: prismaContent });
return { model, dmmf, modelFile };
}

function loadModule(module: string, basePath: string): any {
const modulePath = require.resolve(module, { paths: [basePath] });
return require(modulePath);
}
Loading