Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 19 additions & 42 deletions packages/schema/src/plugins/enhancer/enhance/auth-type-generator.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { getIdFields, hasAttribute, isAuthInvocation, isDataModelFieldReference } from '@zenstackhq/sdk';
import { getIdFields, isAuthInvocation, isDataModelFieldReference } from '@zenstackhq/sdk';
import {
DataModel,
DataModelField,
Expand All @@ -18,41 +18,27 @@ export function generateAuthType(model: Model, authModel: DataModel) {
const types = new Map<
string,
{
// scalar fields to directly pick from Prisma-generated type
pickFields: string[];

// relation fields to include
addFields: { name: string; type: string }[];
// relation fields to require
requiredRelations: { name: string; type: string }[];
}
>();

types.set(authModel.name, { pickFields: getIdFields(authModel).map((f) => f.name), addFields: [] });
types.set(authModel.name, { requiredRelations: [] });

const ensureType = (model: string) => {
if (!types.has(model)) {
types.set(model, { pickFields: [], addFields: [] });
}
};

const addPickField = (model: string, field: string) => {
let fields = types.get(model);
if (!fields) {
fields = { pickFields: [], addFields: [] };
types.set(model, fields);
}
if (!fields.pickFields.includes(field)) {
fields.pickFields.push(field);
types.set(model, { requiredRelations: [] });
}
};

const addAddField = (model: string, name: string, type: string, array: boolean) => {
let fields = types.get(model);
if (!fields) {
fields = { pickFields: [], addFields: [] };
fields = { requiredRelations: [] };
types.set(model, fields);
}
if (!fields.addFields.find((f) => f.name === name)) {
fields.addFields.push({ name, type: array ? `${type}[]` : type });
if (!fields.requiredRelations.find((f) => f.name === name)) {
fields.requiredRelations.push({ name, type: array ? `${type}[]` : type });
}
};

Expand All @@ -71,11 +57,6 @@ export function generateAuthType(model: Model, authModel: DataModel) {
const fieldType = memberDecl.type.reference.ref.name;
ensureType(fieldType);
addAddField(exprType.name, memberDecl.name, fieldType, memberDecl.type.array);
} else {
// member is a scalar
if (!isIgnoredField(node.member.ref)) {
addPickField(exprType.name, node.member.$refText);
}
}
}
}
Expand All @@ -88,11 +69,6 @@ export function generateAuthType(model: Model, authModel: DataModel) {
// field is a relation
ensureType(fieldType.name);
addAddField(fieldDecl.$container.name, node.target.$refText, fieldType.name, fieldDecl.type.array);
} else {
if (!isIgnoredField(fieldDecl)) {
// field is a scalar
addPickField(fieldDecl.$container.name, node.target.$refText);
}
}
}
});
Expand All @@ -112,16 +88,21 @@ ${Array.from(types.entries())
.map(([model, fields]) => {
let result = `Partial<_P.${model}>`;

if (fields.pickFields.length > 0) {
result = `WithRequired<${result}, ${fields.pickFields
.map((f) => `'${f}'`)
.join('|')}> & Record<string, unknown>`;
if (model === authModel.name) {
// auth model's id fields are always required
const idFields = getIdFields(authModel).map((f) => f.name);
if (idFields.length > 0) {
result = `WithRequired<${result}, ${idFields.map((f) => `'${f}'`).join('|')}>`;
}
}

if (fields.addFields.length > 0) {
result = `${result} & { ${fields.addFields.map(({ name, type }) => `${name}: ${type}`).join('; ')} }`;
if (fields.requiredRelations.length > 0) {
// merge required relation fields
result = `${result} & { ${fields.requiredRelations.map((f) => `${f.name}: ${f.type}`).join('; ')} }`;
}

result = `${result} & Record<string, unknown>`;

return ` export type ${model} = ${result};`;
})
.join('\n')}
Expand All @@ -145,7 +126,3 @@ function isAuthAccess(node: AstNode): node is Expression {

return false;
}

function isIgnoredField(field: DataModelField | undefined) {
return !!(field && hasAttribute(field, '@ignore'));
}
26 changes: 26 additions & 0 deletions tests/integration/tests/enhancements/with-policy/auth.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -809,4 +809,30 @@ describe('auth() compile-time test', () => {
}
);
});

it('optional field stays optional', async () => {
await loadSchema(
`
model User {
id Int @id
age Int?
@@allow('all', auth().age > 0)
}
`,
{
compile: true,
extraSourceFiles: [
{
name: 'main.ts',
content: `
import { enhance } from ".zenstack/enhance";
import { PrismaClient } from '@prisma/client';
enhance(new PrismaClient(), { user: { id: 1 } });
`,
},
],
}
);
});
});