Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit c3b456a

Browse files
authored
fix: improve clarity of dealing with auth() during policy generation (zenstackhq#293)
1 parent 933012f commit c3b456a

File tree

13 files changed

+659
-140
lines changed

13 files changed

+659
-140
lines changed

packages/language/src/ast.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ export type ResolvedShape = ExpressionType | AbstractDeclaration;
1414
export type ResolvedType = {
1515
decl?: ResolvedShape;
1616
array?: boolean;
17+
nullable?: boolean;
1718
};
1819

1920
export const BinaryExprOperatorPriority: Record<BinaryExpr['operator'], number> = {

packages/runtime/src/validation.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,17 @@ export function validate(validator: z.ZodType, data: unknown) {
1818
throw new ValidationError(fromZodError(err as z.ZodError).message);
1919
}
2020
}
21+
22+
/**
23+
* Check if the given object has all the given fields, not null or undefined
24+
* @param obj
25+
* @param fields
26+
* @returns
27+
*/
28+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
29+
export function hasAllFields(obj: any, fields: string[]) {
30+
if (typeof obj !== 'object' || !obj) {
31+
return false;
32+
}
33+
return fields.every((f) => obj[f] !== undefined && obj[f] !== null);
34+
}

packages/schema/src/language-server/validator/expression-validator.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import { BinaryExpr, Expression, isArrayExpr, isBinaryExpr, isEnum, isLiteralExpr } from '@zenstackhq/language/ast';
22
import { ValidationAcceptor } from 'langium';
3-
import { isAuthInvocation, isDataModelFieldReference, isEnumFieldReference } from '../../utils/ast-utils';
3+
import { getDataModelFieldReference, isAuthInvocation, isEnumFieldReference } from '../../utils/ast-utils';
44
import { AstValidator } from '../types';
55

66
/**
@@ -33,7 +33,7 @@ export default class ExpressionValidator implements AstValidator<Expression> {
3333
private validateBinaryExpr(expr: BinaryExpr, accept: ValidationAcceptor) {
3434
switch (expr.operator) {
3535
case 'in': {
36-
if (!isDataModelFieldReference(expr.left)) {
36+
if (!getDataModelFieldReference(expr.left)) {
3737
accept('error', 'left operand of "in" must be a field reference', { node: expr.left });
3838
}
3939

packages/schema/src/language-server/validator/function-invocation-validator.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import {
88
isLiteralExpr,
99
} from '@zenstackhq/language/ast';
1010
import { ValidationAcceptor } from 'langium';
11-
import { isDataModelFieldReference, isEnumFieldReference } from '../../utils/ast-utils';
11+
import { getDataModelFieldReference, isEnumFieldReference } from '../../utils/ast-utils';
1212
import { FILTER_OPERATOR_FUNCTIONS } from '../constants';
1313
import { AstValidator } from '../types';
1414
import { isFromStdlib } from '../utils';
@@ -38,7 +38,7 @@ export default class FunctionInvocationValidator implements AstValidator<Express
3838
// first argument must refer to a model field
3939
const firstArg = expr.args?.[0]?.value;
4040
if (firstArg) {
41-
if (!isDataModelFieldReference(firstArg)) {
41+
if (!getDataModelFieldReference(firstArg)) {
4242
accept('error', 'first argument must be a field reference', { node: firstArg });
4343
}
4444
}

packages/schema/src/language-server/zmodel-linker.ts

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import {
1515
isArrayExpr,
1616
isDataModel,
1717
isDataModelField,
18+
isDataModelFieldType,
1819
isReferenceExpr,
1920
LiteralExpr,
2021
MemberAccessExpr,
@@ -249,7 +250,7 @@ export class ZModelLinker extends DefaultLinker {
249250
const model = getContainingModel(node);
250251
const userModel = model?.declarations.find((d) => isDataModel(d) && d.name === 'User');
251252
if (userModel) {
252-
node.$resolvedType = { decl: userModel };
253+
node.$resolvedType = { decl: userModel, nullable: true };
253254
}
254255
} else if (funcDecl.name === 'future' && isFromStdlib(funcDecl)) {
255256
// future() function is resolved to current model
@@ -447,19 +448,24 @@ export class ZModelLinker extends DefaultLinker {
447448
//#region Utils
448449

449450
private resolveToDeclaredType(node: AstNode, type: FunctionParamType | DataModelFieldType) {
451+
let nullable = false;
452+
if (isDataModelFieldType(type)) {
453+
nullable = type.optional;
454+
}
450455
if (type.type) {
451456
const mappedType = mapBuiltinTypeToExpressionType(type.type);
452-
node.$resolvedType = { decl: mappedType, array: type.array };
457+
node.$resolvedType = { decl: mappedType, array: type.array, nullable: nullable };
453458
} else if (type.reference) {
454459
node.$resolvedType = {
455460
decl: type.reference.ref,
456461
array: type.array,
462+
nullable: nullable,
457463
};
458464
}
459465
}
460466

461-
private resolveToBuiltinTypeOrDecl(node: AstNode, type: ResolvedShape, array = false) {
462-
node.$resolvedType = { decl: type, array };
467+
private resolveToBuiltinTypeOrDecl(node: AstNode, type: ResolvedShape, array = false, nullable = false) {
468+
node.$resolvedType = { decl: type, array, nullable };
463469
}
464470

465471
//#endregion

packages/schema/src/plugins/access-policy/expression-writer.ts

Lines changed: 130 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import {
1717
import { getLiteral, GUARD_FIELD_NAME, PluginError } from '@zenstackhq/sdk';
1818
import { CodeBlockWriter } from 'ts-morph';
1919
import { FILTER_OPERATOR_FUNCTIONS } from '../../language-server/constants';
20-
import { getIdField, isAuthInvocation } from '../../utils/ast-utils';
20+
import { getIdFields, isAuthInvocation } from '../../utils/ast-utils';
2121
import TypeScriptExpressionTransformer from './typescript-expression-transformer';
2222
import { isFutureExpr } from './utils';
2323

@@ -99,12 +99,17 @@ export class ExpressionWriter {
9999

100100
private writeMemberAccess(expr: MemberAccessExpr) {
101101
this.block(() => {
102-
// must be a boolean member
103-
this.writeFieldCondition(expr.operand, () => {
104-
this.block(() => {
105-
this.writer.write(`${expr.member.ref?.name}: true`);
102+
if (this.isAuthOrAuthMemberAccess(expr)) {
103+
// member access of `auth()`, generate plain expression
104+
this.guard(() => this.plain(expr), true);
105+
} else {
106+
// must be a boolean member
107+
this.writeFieldCondition(expr.operand, () => {
108+
this.block(() => {
109+
this.writer.write(`${expr.member.ref?.name}: true`);
110+
});
106111
});
107-
});
112+
}
108113
});
109114
}
110115

@@ -190,9 +195,14 @@ export class ExpressionWriter {
190195
return false;
191196
}
192197

193-
private guard(write: () => void) {
198+
private guard(write: () => void, cast = false) {
194199
this.writer.write(`${GUARD_FIELD_NAME}: `);
195-
write();
200+
if (cast) {
201+
this.writer.write('!!');
202+
write();
203+
} else {
204+
write();
205+
}
196206
}
197207

198208
private plain(expr: Expression) {
@@ -211,12 +221,9 @@ export class ExpressionWriter {
211221
// compile down to a plain expression
212222
this.block(() => {
213223
this.guard(() => {
214-
this.plain(expr.left);
215-
this.writer.write(' ' + operator + ' ');
216-
this.plain(expr.right);
224+
this.plain(expr);
217225
});
218226
});
219-
220227
return;
221228
}
222229

@@ -242,65 +249,105 @@ export class ExpressionWriter {
242249
} as ReferenceExpr;
243250
}
244251

245-
// if the operand refers to auth(), need to build a guard to avoid
246-
// using undefined user as filter (which means no filter to Prisma)
247-
// if auth() evaluates falsy, just treat the condition as false
248-
if (this.isAuthOrAuthMemberAccess(operand)) {
249-
this.writer.write(`!user ? { ${GUARD_FIELD_NAME}: false } : `);
252+
// guard member access of `auth()` with null check
253+
if (this.isAuthOrAuthMemberAccess(operand) && !fieldAccess.$resolvedType?.nullable) {
254+
this.writer.write(
255+
`(${this.plainExprBuilder.transform(operand)} == null) ? { ${GUARD_FIELD_NAME}: ${
256+
// auth().x != user.x is true when auth().x is null and user is not nullable
257+
// other expressions are evaluated to false when null is involved
258+
operator === '!=' ? 'true' : 'false'
259+
} } : `
260+
);
250261
}
251262

252-
this.block(() => {
253-
this.writeFieldCondition(fieldAccess, () => {
254-
this.block(
255-
() => {
263+
this.block(
264+
() => {
265+
this.writeFieldCondition(fieldAccess, () => {
266+
this.block(() => {
256267
const dataModel = this.isModelTyped(fieldAccess);
257-
if (dataModel) {
258-
const idField = getIdField(dataModel);
259-
if (!idField) {
268+
if (dataModel && isAuthInvocation(operand)) {
269+
// right now this branch only serves comparison with `auth`, like
270+
// @@allow('all', owner == auth())
271+
272+
const idFields = getIdFields(dataModel);
273+
if (!idFields || idFields.length === 0) {
260274
throw new PluginError(`Data model ${dataModel.name} does not have an id field`);
261275
}
262-
// comparing with an object, convert to "id" comparison instead
263-
this.writer.write(`${idField.name}: `);
276+
277+
if (operator !== '==' && operator !== '!=') {
278+
throw new PluginError('Only == and != operators are allowed');
279+
}
280+
281+
if (!isThisExpr(fieldAccess)) {
282+
this.writer.writeLine(operator === '==' ? 'is:' : 'isNot:');
283+
const fieldIsNullable = !!fieldAccess.$resolvedType?.nullable;
284+
if (fieldIsNullable) {
285+
// if field is nullable, we can generate "null" check condition
286+
this.writer.write(`(user == null) ? null : `);
287+
}
288+
}
289+
264290
this.block(() => {
265-
this.writeOperator(operator, () => {
266-
// operand ? operand.field : null
267-
this.writer.write('(');
268-
this.plain(operand);
269-
this.writer.write(' ? ');
270-
this.plain(operand);
271-
this.writer.write(`.${idField.name}`);
272-
this.writer.write(' : null');
273-
this.writer.write(')');
291+
idFields.forEach((idField, idx) => {
292+
const writeIdsCheck = () => {
293+
// id: user.id
294+
this.writer.write(`${idField.name}:`);
295+
this.plain(operand);
296+
this.writer.write(`.${idField.name}`);
297+
if (idx !== idFields.length - 1) {
298+
this.writer.write(',');
299+
}
300+
};
301+
302+
if (isThisExpr(fieldAccess) && operator === '!=') {
303+
// wrap a not
304+
this.writer.writeLine('NOT:');
305+
this.block(() => writeIdsCheck());
306+
} else {
307+
writeIdsCheck();
308+
}
274309
});
275310
});
276311
} else {
277-
this.writeOperator(operator, () => {
312+
this.writeOperator(operator, fieldAccess, () => {
278313
this.plain(operand);
279314
});
280315
}
281-
},
282-
// "this" expression is compiled away (to .id access), so we should
283-
// avoid generating a new layer
284-
!isThisExpr(fieldAccess)
285-
);
286-
});
287-
});
316+
}, !isThisExpr(fieldAccess));
317+
});
318+
},
319+
// "this" expression is compiled away (to .id access), so we should
320+
// avoid generating a new layer
321+
!isThisExpr(fieldAccess)
322+
);
288323
}
289324

290325
private isAuthOrAuthMemberAccess(expr: Expression) {
291326
return isAuthInvocation(expr) || (isMemberAccessExpr(expr) && isAuthInvocation(expr.operand));
292327
}
293328

294-
private writeOperator(operator: ComparisonOperator, writeOperand: () => void) {
295-
if (operator === '!=') {
296-
// wrap a 'not'
297-
this.writer.write('not: ');
298-
this.block(() => {
299-
this.writeOperator('==', writeOperand);
300-
});
301-
} else {
302-
this.writer.write(`${this.mapOperator(operator)}: `);
329+
private writeOperator(operator: ComparisonOperator, fieldAccess: Expression, writeOperand: () => void) {
330+
if (isDataModel(fieldAccess.$resolvedType?.decl)) {
331+
if (operator === '==') {
332+
this.writer.write('is: ');
333+
} else if (operator === '!=') {
334+
this.writer.write('isNot: ');
335+
} else {
336+
throw new PluginError('Only == and != operators are allowed for data model comparison');
337+
}
303338
writeOperand();
339+
} else {
340+
if (operator === '!=') {
341+
// wrap a 'not'
342+
this.writer.write('not: ');
343+
this.block(() => {
344+
this.writer.write(`${this.mapOperator('==')}: `);
345+
writeOperand();
346+
});
347+
} else {
348+
this.writer.write(`${this.mapOperator(operator)}: `);
349+
writeOperand();
350+
}
304351
}
305352
}
306353

@@ -414,10 +461,37 @@ export class ExpressionWriter {
414461
}
415462

416463
private writeLogical(expr: BinaryExpr, operator: '&&' | '||') {
417-
this.block(() => {
418-
this.writer.write(`${operator === '&&' ? 'AND' : 'OR'}: `);
419-
this.writeExprList([expr.left, expr.right]);
420-
});
464+
// TODO: do we need short-circuit for logical operators?
465+
466+
if (operator === '&&') {
467+
// // && short-circuit: left && right -> left ? right : { zenstack_guard: false }
468+
// if (!this.hasFieldAccess(expr.left)) {
469+
// this.plain(expr.left);
470+
// this.writer.write(' ? ');
471+
// this.write(expr.right);
472+
// this.writer.write(' : ');
473+
// this.block(() => this.guard(() => this.writer.write('false')));
474+
// } else {
475+
this.block(() => {
476+
this.writer.write('AND:');
477+
this.writeExprList([expr.left, expr.right]);
478+
});
479+
// }
480+
} else {
481+
// // || short-circuit: left || right -> left ? { zenstack_guard: true } : right
482+
// if (!this.hasFieldAccess(expr.left)) {
483+
// this.plain(expr.left);
484+
// this.writer.write(' ? ');
485+
// this.block(() => this.guard(() => this.writer.write('true')));
486+
// this.writer.write(' : ');
487+
// this.write(expr.right);
488+
// } else {
489+
this.block(() => {
490+
this.writer.write('OR:');
491+
this.writeExprList([expr.left, expr.right]);
492+
});
493+
// }
494+
}
421495
}
422496

423497
private writeUnary(expr: UnaryExpr) {

packages/schema/src/plugins/access-policy/policy-guard-generator.ts

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import path from 'path';
2121
import { FunctionDeclaration, Project, SourceFile, VariableDeclarationKind } from 'ts-morph';
2222
import { name } from '.';
2323
import { isFromStdlib } from '../../language-server/utils';
24-
import { analyzePolicies, getIdField } from '../../utils/ast-utils';
24+
import { analyzePolicies, getIdFields } from '../../utils/ast-utils';
2525
import { ALL_OPERATION_KINDS, getDefaultOutputFolder, RUNTIME_PACKAGE } from '../plugin-utils';
2626
import { ExpressionWriter } from './expression-writer';
2727
import { isFutureExpr } from './utils';
@@ -42,9 +42,8 @@ export default class PolicyGenerator {
4242
const sf = project.createSourceFile(path.join(output, 'policy.ts'), undefined, { overwrite: true });
4343

4444
sf.addImportDeclaration({
45-
namedImports: [{ name: 'QueryContext' }],
45+
namedImports: [{ name: 'type QueryContext' }, { name: 'hasAllFields' }],
4646
moduleSpecifier: `${RUNTIME_PACKAGE}`,
47-
isTypeOnly: true,
4847
});
4948

5049
sf.addImportDeclaration({
@@ -329,13 +328,17 @@ export default class PolicyGenerator {
329328
if (!userModel) {
330329
throw new PluginError('User model not found');
331330
}
332-
const userIdField = getIdField(userModel);
333-
if (!userIdField) {
331+
const userIdFields = getIdFields(userModel);
332+
if (!userIdFields || userIdFields.length === 0) {
334333
throw new PluginError('User model does not have an id field');
335334
}
336335

337336
// normalize user to null to avoid accidentally use undefined in filter
338-
func.addStatements(`const user = context.user ?? null;`);
337+
func.addStatements(
338+
`const user = hasAllFields(context.user, [${userIdFields
339+
.map((f) => "'" + f.name + "'")
340+
.join(', ')}]) ? context.user : null;`
341+
);
339342
}
340343

341344
// r = <guard object>;

0 commit comments

Comments
 (0)