Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 15 additions & 49 deletions packages/runtime/src/enhancements/delegate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ import {
FieldInfo,
ModelInfo,
NestedWriteVisitor,
clone,
enumerate,
getIdFields,
getModelInfo,
isDelegateModel,
resolveField,
} from '../cross';
import { clone } from '../cross';
import type { CrudContract, DbClientContract } from '../types';
import type { InternalEnhancementOptions } from './create-enhancement';
import { Logger } from './logger';
Expand Down Expand Up @@ -79,7 +79,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {

if (args.orderBy) {
// `orderBy` may contain fields from base types
args.orderBy = this.buildWhereHierarchy(this.model, args.orderBy);
this.injectWhereHierarchy(this.model, args.orderBy);
}

if (this.options.logPrismaQuery) {
Expand All @@ -95,7 +95,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
}

private injectWhereHierarchy(model: string, where: any) {
if (!where || typeof where !== 'object') {
if (!where || !isPlainObject(where)) {
return;
}

Expand All @@ -108,44 +108,9 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {

const fieldInfo = resolveField(this.options.modelMeta, model, field);
if (!fieldInfo?.inheritedFrom) {
return;
}

let base = this.getBaseModel(model);
let target = where;

while (base) {
const baseRelationName = this.makeAuxRelationName(base);

// prepare base layer where
let thisLayer: any;
if (target[baseRelationName]) {
thisLayer = target[baseRelationName];
} else {
thisLayer = target[baseRelationName] = {};
}

if (base.name === fieldInfo.inheritedFrom) {
thisLayer[field] = value;
delete where[field];
break;
} else {
target = thisLayer;
base = this.getBaseModel(base.name);
if (fieldInfo?.isDataModel) {
this.injectWhereHierarchy(fieldInfo.type, value);
}
}
});
}

private buildWhereHierarchy(model: string, where: any) {
if (!where) {
return undefined;
}

where = clone(where);
Object.entries(where).forEach(([field, value]) => {
const fieldInfo = resolveField(this.options.modelMeta, model, field);
if (!fieldInfo?.inheritedFrom) {
return;
}

Expand All @@ -164,6 +129,9 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
}

if (base.name === fieldInfo.inheritedFrom) {
if (fieldInfo.isDataModel) {
this.injectWhereHierarchy(base.name, value);
}
thisLayer[field] = value;
delete where[field];
break;
Expand All @@ -173,8 +141,6 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
}
}
});

return where;
}

private injectSelectIncludeHierarchy(model: string, args: any) {
Expand All @@ -189,7 +155,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
if (fieldInfo && value !== undefined) {
if (value?.orderBy) {
// `orderBy` may contain fields from base types
value.orderBy = this.buildWhereHierarchy(fieldInfo.type, value.orderBy);
this.injectWhereHierarchy(fieldInfo.type, value.orderBy);
}

if (this.injectBaseFieldSelect(model, field, value, args, kind)) {
Expand Down Expand Up @@ -921,15 +887,15 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
args = clone(args);

if (args.cursor) {
args.cursor = this.buildWhereHierarchy(this.model, args.cursor);
this.injectWhereHierarchy(this.model, args.cursor);
}

if (args.orderBy) {
args.orderBy = this.buildWhereHierarchy(this.model, args.orderBy);
this.injectWhereHierarchy(this.model, args.orderBy);
}

if (args.where) {
args.where = this.buildWhereHierarchy(this.model, args.where);
this.injectWhereHierarchy(this.model, args.where);
}

if (this.options.logPrismaQuery) {
Expand All @@ -949,11 +915,11 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
args = clone(args);

if (args?.cursor) {
args.cursor = this.buildWhereHierarchy(this.model, args.cursor);
this.injectWhereHierarchy(this.model, args.cursor);
}

if (args?.where) {
args.where = this.buildWhereHierarchy(this.model, args.where);
this.injectWhereHierarchy(this.model, args.where);
}

if (this.options.logPrismaQuery) {
Expand Down Expand Up @@ -989,7 +955,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
args = clone(args);

if (args.where) {
args.where = this.buildWhereHierarchy(this.model, args.where);
this.injectWhereHierarchy(this.model, args.where);
}

if (this.options.logPrismaQuery) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,70 @@ describe('Polymorphism Test', () => {
});
});

it('read with compound filter', async () => {
const { enhance } = await loadSchema(
`
model Base {
id Int @id @default(autoincrement())
type String
viewCount Int
@@delegate(type)
}

model Foo extends Base {
name String
}
`,
{ enhancements: ['delegate'] }
);

const db = enhance();
await db.foo.create({ data: { name: 'foo1', viewCount: 0 } });
await db.foo.create({ data: { name: 'foo2', viewCount: 1 } });

await expect(db.foo.findMany({ where: { viewCount: { gt: 0 } } })).resolves.toHaveLength(1);
await expect(db.foo.findMany({ where: { AND: { viewCount: { gt: 0 } } } })).resolves.toHaveLength(1);
await expect(db.foo.findMany({ where: { AND: [{ viewCount: { gt: 0 } }] } })).resolves.toHaveLength(1);
await expect(db.foo.findMany({ where: { OR: [{ viewCount: { gt: 0 } }] } })).resolves.toHaveLength(1);
await expect(db.foo.findMany({ where: { NOT: { viewCount: { lte: 0 } } } })).resolves.toHaveLength(1);
});

it('read with nested filter', async () => {
const { enhance } = await loadSchema(
`
model Base {
id Int @id @default(autoincrement())
type String
viewCount Int
@@delegate(type)
}

model Foo extends Base {
name String
bar Bar?
}

model Bar extends Base {
foo Foo @relation(fields: [fooId], references: [id])
fooId Int @unique
}
`,
{ enhancements: ['delegate'] }
);

const db = enhance();

await db.bar.create({
data: { foo: { create: { name: 'foo', viewCount: 2 } }, viewCount: 1 },
});

await expect(
db.bar.findMany({
where: { viewCount: { gt: 0 }, foo: { viewCount: { gt: 1 } } },
})
).resolves.toHaveLength(1);
});

it('order by base fields', async () => {
const { db, user } = await setup();

Expand Down Expand Up @@ -1013,6 +1077,18 @@ describe('Polymorphism Test', () => {
});
expect(count).toMatchObject({ _all: 1, rating: 1 });

count = await db.ratedVideo.count({
select: { _all: true, rating: true },
where: { AND: { viewCount: { gt: 0 }, rating: { gt: 10 } } },
});
expect(count).toMatchObject({ _all: 1, rating: 1 });

count = await db.ratedVideo.count({
select: { _all: true, rating: true },
where: { AND: [{ viewCount: { gt: 0 }, rating: { gt: 10 } }] },
});
expect(count).toMatchObject({ _all: 1, rating: 1 });

expect(() => db.ratedVideo.count({ select: { rating: true, viewCount: true } })).toThrow(
'count with fields from base type is not supported yet'
);
Expand Down
30 changes: 30 additions & 0 deletions tests/regression/tests/issue-1585.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import { loadSchema } from '@zenstackhq/testtools';
describe('issue 1585', () => {
it('regression', async () => {
const { enhance } = await loadSchema(
`
model Asset {
id Int @id @default(autoincrement())
type String
views Int

@@allow('all', true)
@@delegate(type)
}

model Post extends Asset {
title String
}
`
);

const db = enhance();
await db.post.create({ data: { title: 'Post1', views: 0 } });
await db.post.create({ data: { title: 'Post2', views: 1 } });
await expect(
db.post.count({
where: { views: { gt: 0 } },
})
).resolves.toBe(1);
});
});