Skip to content

Commit 3c6e5ee

Browse files
committed
more fixes
1 parent 7c0277d commit 3c6e5ee

File tree

7 files changed

+202
-52
lines changed

7 files changed

+202
-52
lines changed

packages/runtime/src/enhancements/policy/constraint-solver.ts

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ export class ConstraintSolver {
5757
(c) => this.buildVariableFormula(c)
5858
)
5959
.when(
60-
(c): c is ComparisonConstraint => ['eq', 'gt', 'gte', 'lt', 'lte'].includes(c.kind),
60+
(c): c is ComparisonConstraint => ['eq', 'ne', 'gt', 'gte', 'lt', 'lte'].includes(c.kind),
6161
(c) => this.buildComparisonFormula(c)
6262
)
6363
.when(
@@ -71,24 +71,51 @@ export class ConstraintSolver {
7171

7272
private buildLogicalFormula(constraint: LogicalConstraint) {
7373
return match(constraint.kind)
74-
.with('and', () => Logic.and(...constraint.children.map((c) => this.buildFormula(c))))
75-
.with('or', () => Logic.or(...constraint.children.map((c) => this.buildFormula(c))))
76-
.with('not', () => {
77-
if (constraint.children.length !== 1) {
78-
throw new Error('"not" constraint must have exactly one child');
79-
}
80-
return Logic.not(this.buildFormula(constraint.children[0]));
81-
})
74+
.with('and', () => this.buildAndFormula(constraint))
75+
.with('or', () => this.buildOrFormula(constraint))
76+
.with('not', () => this.buildNotFormula(constraint))
8277
.exhaustive();
8378
}
8479

80+
private buildAndFormula(constraint: LogicalConstraint): Logic.Formula {
81+
if (constraint.children.some((c) => this.isFalse(c))) {
82+
// short-circuit
83+
return Logic.FALSE;
84+
}
85+
return Logic.and(...constraint.children.map((c) => this.buildFormula(c)));
86+
}
87+
88+
private buildOrFormula(constraint: LogicalConstraint): Logic.Formula {
89+
if (constraint.children.some((c) => this.isTrue(c))) {
90+
// short-circuit
91+
return Logic.TRUE;
92+
}
93+
return Logic.or(...constraint.children.map((c) => this.buildFormula(c)));
94+
}
95+
96+
private buildNotFormula(constraint: LogicalConstraint) {
97+
if (constraint.children.length !== 1) {
98+
throw new Error('"not" constraint must have exactly one child');
99+
}
100+
return Logic.not(this.buildFormula(constraint.children[0]));
101+
}
102+
103+
private isTrue(constraint: CheckerConstraint): unknown {
104+
return constraint.kind === 'value' && constraint.value === true;
105+
}
106+
107+
private isFalse(constraint: CheckerConstraint): unknown {
108+
return constraint.kind === 'value' && constraint.value === false;
109+
}
110+
85111
private buildComparisonFormula(constraint: ComparisonConstraint) {
86112
if (constraint.left.kind === 'value' && constraint.right.kind === 'value') {
87113
// constant comparison
88114
const left: ValueConstraint = constraint.left;
89115
const right: ValueConstraint = constraint.right;
90116
return match(constraint.kind)
91117
.with('eq', () => (left.value === right.value ? Logic.TRUE : Logic.FALSE))
118+
.with('ne', () => (left.value !== right.value ? Logic.TRUE : Logic.FALSE))
92119
.with('gt', () => (left.value > right.value ? Logic.TRUE : Logic.FALSE))
93120
.with('gte', () => (left.value >= right.value ? Logic.TRUE : Logic.FALSE))
94121
.with('lt', () => (left.value < right.value ? Logic.TRUE : Logic.FALSE))
@@ -98,6 +125,7 @@ export class ConstraintSolver {
98125

99126
return match(constraint.kind)
100127
.with('eq', () => this.transformEquality(constraint.left, constraint.right))
128+
.with('ne', () => this.transformInequality(constraint.left, constraint.right))
101129
.with('gt', () =>
102130
this.transformComparison(constraint.left, constraint.right, (l, r) => Logic.greaterThan(l, r))
103131
)
@@ -177,6 +205,10 @@ export class ConstraintSolver {
177205
}
178206
}
179207

208+
private transformInequality(left: ComparisonTerm, right: ComparisonTerm) {
209+
return Logic.not(this.transformEquality(left, right));
210+
}
211+
180212
private transformComparison(
181213
left: ComparisonTerm,
182214
right: ComparisonTerm,

packages/runtime/src/enhancements/types.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ export type ComparisonTerm = VariableConstraint | ValueConstraint;
6666
* Comparison constraint
6767
*/
6868
export type ComparisonConstraint = {
69-
kind: 'eq' | 'gt' | 'gte' | 'lt' | 'lte';
69+
kind: 'eq' | 'ne' | 'gt' | 'gte' | 'lt' | 'lte';
7070
left: ComparisonTerm;
7171
right: ComparisonTerm;
7272
};

packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts

Lines changed: 91 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
import { ZModelCodeGenerator, isAuthInvocation } from '@zenstackhq/sdk';
1+
import { ZModelCodeGenerator, getRelationKeyPairs, isAuthInvocation, isDataModelFieldReference } from '@zenstackhq/sdk';
22
import {
33
BinaryExpr,
44
BooleanLiteral,
5+
DataModelField,
56
Expression,
67
ExpressionType,
78
LiteralExpr,
@@ -160,46 +161,28 @@ export class ConstraintTransformer {
160161
}
161162

162163
private transformComparison(expr: BinaryExpr) {
163-
if (this.isAuthEqualNull(expr)) {
164-
// `auth() == null` => `user === null`
165-
return this.value(`${this.options.authAccessor} === null`, 'boolean');
166-
}
167-
168-
if (this.isAuthNotEqualNull(expr)) {
169-
// `auth() != null` => `user !== null`
170-
return this.value(`${this.options.authAccessor} !== null`, 'boolean');
164+
if (isAuthInvocation(expr.left) || isAuthInvocation(expr.right)) {
165+
// handle the case if any operand is `auth()` invocation
166+
const authComparison = this.transformAuthComparison(expr);
167+
return authComparison ?? this.nextVar();
171168
}
172169

173170
const leftOperand = this.getComparisonOperand(expr.left);
174171
const rightOperand = this.getComparisonOperand(expr.right);
175172

176-
const op = match(expr.operator)
177-
.with('==', () => 'eq')
178-
.with('!=', () => 'eq')
179-
.with('<', () => 'lt')
180-
.with('<=', () => 'lte')
181-
.with('>', () => 'gt')
182-
.with('>=', () => 'gte')
183-
.otherwise(() => {
184-
throw new Error(`Unsupported operator: ${expr.operator}`);
185-
});
186-
187-
let result = `{ kind: '${op}', left: ${leftOperand}, right: ${rightOperand} }`;
188-
if (expr.operator === '!=') {
189-
// transform "!=" into "not eq"
190-
result = this.not(result);
191-
}
173+
const op = this.mapOperatorToConstraintKind(expr.operator);
174+
const result = `{ kind: '${op}', left: ${leftOperand}, right: ${rightOperand} }`;
192175

193-
// `auth()` access can be undefined, when that happens, we assume a false condition
194-
// for the comparison, unless we're directly comparing `auth() != null`
176+
// `auth()` member access can be undefined, when that happens, we assume a false condition
177+
// for the comparison
195178

196179
const leftAuthAccess = this.getAuthAccess(expr.left);
197180
const rightAuthAccess = this.getAuthAccess(expr.right);
198181

199-
if (leftAuthAccess) {
182+
if (leftAuthAccess && rightOperand) {
200183
// `auth().f op x` => `auth().f !== undefined && auth().f op x`
201184
return this.and(this.value(`${this.normalizeToNull(leftAuthAccess)} !== null`, 'boolean'), result);
202-
} else if (rightAuthAccess) {
185+
} else if (rightAuthAccess && leftOperand) {
203186
// `x op auth().f` => `auth().f !== undefined && x op auth().f`
204187
return this.and(this.value(`${this.normalizeToNull(rightAuthAccess)} !== null`, 'boolean'), result);
205188
}
@@ -212,6 +195,64 @@ export class ConstraintTransformer {
212195
return result;
213196
}
214197

198+
private transformAuthComparison(expr: BinaryExpr) {
199+
if (this.isAuthEqualNull(expr)) {
200+
// `auth() == null` => `user === null`
201+
return this.value(`${this.options.authAccessor} === null`, 'boolean');
202+
}
203+
204+
if (this.isAuthNotEqualNull(expr)) {
205+
// `auth() != null` => `user !== null`
206+
return this.value(`${this.options.authAccessor} !== null`, 'boolean');
207+
}
208+
209+
// auth() equality check against a relation, translate to id-fk comparison
210+
const operand = isAuthInvocation(expr.left) ? expr.right : expr.left;
211+
if (!isDataModelFieldReference(operand)) {
212+
return undefined;
213+
}
214+
215+
// get id-fk field pairs from the relation field
216+
const relationField = operand.target.ref as DataModelField;
217+
const idFkPairs = getRelationKeyPairs(relationField);
218+
219+
// build id-fk field comparison constraints
220+
const fieldConstraints: string[] = [];
221+
222+
idFkPairs.forEach(({ id, foreignKey }) => {
223+
const idFieldType = this.mapType(id.type.type as ExpressionType);
224+
if (!idFieldType) {
225+
return;
226+
}
227+
const fkFieldType = this.mapType(foreignKey.type.type as ExpressionType);
228+
if (!fkFieldType) {
229+
return;
230+
}
231+
232+
const op = this.mapOperatorToConstraintKind(expr.operator);
233+
const authIdAccess = `${this.options.authAccessor}?.${id.name}`;
234+
235+
fieldConstraints.push(
236+
this.and(
237+
// `auth()?.id != null` guard
238+
this.value(`${this.normalizeToNull(authIdAccess)} !== null`, 'boolean'),
239+
// `auth()?.id [op] fkField`
240+
`{ kind: '${op}', left: ${this.value(authIdAccess, idFieldType)}, right: ${this.variable(
241+
foreignKey.name,
242+
fkFieldType
243+
)} }`
244+
)
245+
);
246+
});
247+
248+
// combine field constraints
249+
if (fieldConstraints.length > 0) {
250+
return this.and(...fieldConstraints);
251+
}
252+
253+
return undefined;
254+
}
255+
215256
// normalize `auth()` access undefined value to null
216257
private normalizeToNull(expr: string) {
217258
return `(${expr} ?? null)`;
@@ -241,7 +282,7 @@ export class ConstraintTransformer {
241282
const fieldAccess = this.getFieldAccess(expr);
242283
if (fieldAccess) {
243284
// model field access is transformed into a named variable
244-
const mappedType = this.mapType(expr);
285+
const mappedType = this.mapExpressionType(expr);
245286
if (mappedType) {
246287
return this.variable(fieldAccess.name, mappedType);
247288
} else {
@@ -251,7 +292,7 @@ export class ConstraintTransformer {
251292

252293
const authAccess = this.getAuthAccess(expr);
253294
if (authAccess) {
254-
const mappedType = this.mapType(expr);
295+
const mappedType = this.mapExpressionType(expr);
255296
if (mappedType) {
256297
return `${this.value(authAccess, mappedType)}`;
257298
} else {
@@ -262,14 +303,31 @@ export class ConstraintTransformer {
262303
return undefined;
263304
}
264305

265-
private mapType(expression: Expression) {
266-
return match(expression.$resolvedType?.decl as ExpressionType)
306+
private mapExpressionType(expression: Expression) {
307+
return this.mapType(expression.$resolvedType?.decl as ExpressionType);
308+
}
309+
310+
private mapType(type: ExpressionType) {
311+
return match(type)
267312
.with('Boolean', () => 'boolean')
268313
.with('Int', () => 'number')
269314
.with('String', () => 'string')
270315
.otherwise(() => undefined);
271316
}
272317

318+
private mapOperatorToConstraintKind(operator: BinaryExpr['operator']) {
319+
return match(operator)
320+
.with('==', () => 'eq')
321+
.with('!=', () => 'ne')
322+
.with('<', () => 'lt')
323+
.with('<=', () => 'lte')
324+
.with('>', () => 'gt')
325+
.with('>=', () => 'gte')
326+
.otherwise(() => {
327+
throw new Error(`Unsupported operator: ${operator}`);
328+
});
329+
}
330+
273331
private getFieldAccess(expr: Expression) {
274332
if (isReferenceExpr(expr)) {
275333
return isDataModelField(expr.target.ref) ? { name: expr.target.$refText } : undefined;

packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -922,7 +922,7 @@ export class PolicyGenerator {
922922
statements.push(`return ${transformed};`);
923923

924924
const func = sourceFile.addFunction({
925-
name: `${model.name}Checker_${kind}`,
925+
name: `${model.name}$checker$${kind}`,
926926
returnType: 'CheckerConstraint',
927927
parameters: [
928928
{

packages/sdk/src/utils.ts

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -298,9 +298,9 @@ export function isForeignKeyField(field: DataModelField) {
298298
}
299299

300300
/**
301-
* Gets the foreign key fields of the given relation field.
301+
* Gets the foreign key-id field pairs from the given relation field.
302302
*/
303-
export function getForeignKeyFields(relationField: DataModelField) {
303+
export function getRelationKeyPairs(relationField: DataModelField) {
304304
if (!isRelationshipField(relationField)) {
305305
return [];
306306
}
@@ -309,11 +309,31 @@ export function getForeignKeyFields(relationField: DataModelField) {
309309
if (relAttr) {
310310
// find "fields" arg
311311
const fieldsArg = getAttributeArg(relAttr, 'fields');
312+
let fkFields: DataModelField[];
312313
if (fieldsArg && isArrayExpr(fieldsArg)) {
313-
return fieldsArg.items
314+
fkFields = fieldsArg.items
314315
.filter((item): item is ReferenceExpr => isReferenceExpr(item))
315316
.map((item) => item.target.ref as DataModelField);
317+
} else {
318+
return [];
316319
}
320+
321+
// find "references" arg
322+
const referencesArg = getAttributeArg(relAttr, 'references');
323+
let idFields: DataModelField[];
324+
if (referencesArg && isArrayExpr(referencesArg)) {
325+
idFields = referencesArg.items
326+
.filter((item): item is ReferenceExpr => isReferenceExpr(item))
327+
.map((item) => item.target.ref as DataModelField);
328+
} else {
329+
return [];
330+
}
331+
332+
if (idFields.length !== fkFields.length) {
333+
throw new Error(`Relation's references arg and fields are must have equal length`);
334+
}
335+
336+
return idFields.map((idField, i) => ({ id: idField, foreignKey: fkFields[i] }));
317337
}
318338

319339
return [];

tests/integration/tests/enhancements/with-policy/checker.test.ts

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,48 @@ describe('Permission checker', () => {
306306
await expect(enhance({ id: 1, admin: false }).model.check('update')).toResolveTruthy();
307307
});
308308

309+
it('auth compared with relation field', async () => {
310+
const { enhance } = await loadSchema(
311+
`
312+
model User {
313+
id Int @id @default(autoincrement())
314+
models Model[]
315+
}
316+
317+
model Model {
318+
id Int @id @default(autoincrement())
319+
owner User @relation(fields: [ownerId], references: [id])
320+
ownerId Int
321+
@@allow('read', auth().id == ownerId)
322+
@@allow('create', auth().id != ownerId)
323+
@@allow('update', auth() == owner)
324+
@@allow('delete', auth() != owner)
325+
}
326+
`,
327+
{ preserveTsFiles: true }
328+
);
329+
330+
await expect(enhance().model.check('read')).toResolveFalsy();
331+
await expect(enhance({ id: 1 }).model.check('read')).toResolveTruthy();
332+
await expect(enhance({ id: 1 }).model.check('read', { ownerId: 1 })).toResolveTruthy();
333+
await expect(enhance({ id: 1 }).model.check('read', { ownerId: 2 })).toResolveFalsy();
334+
335+
await expect(enhance().model.check('create')).toResolveFalsy();
336+
await expect(enhance({ id: 1 }).model.check('create')).toResolveTruthy();
337+
await expect(enhance({ id: 1 }).model.check('create', { ownerId: 1 })).toResolveFalsy();
338+
await expect(enhance({ id: 1 }).model.check('create', { ownerId: 2 })).toResolveTruthy();
339+
340+
await expect(enhance().model.check('update')).toResolveFalsy();
341+
await expect(enhance({ id: 1 }).model.check('update')).toResolveTruthy();
342+
await expect(enhance({ id: 1 }).model.check('update', { ownerId: 1 })).toResolveTruthy();
343+
await expect(enhance({ id: 1 }).model.check('update', { ownerId: 2 })).toResolveFalsy();
344+
345+
await expect(enhance().model.check('delete')).toResolveFalsy();
346+
await expect(enhance({ id: 1 }).model.check('delete')).toResolveTruthy();
347+
await expect(enhance({ id: 1 }).model.check('delete', { ownerId: 1 })).toResolveFalsy();
348+
await expect(enhance({ id: 1 }).model.check('delete', { ownerId: 2 })).toResolveTruthy();
349+
});
350+
309351
it('auth null check', async () => {
310352
const { enhance } = await loadSchema(
311353
`
@@ -336,7 +378,7 @@ describe('Permission checker', () => {
336378
await expect(enhance({ id: 1, level: 1 }).model.check('update')).toResolveTruthy();
337379
});
338380

339-
it('auth with relation', async () => {
381+
it('auth with relation access', async () => {
340382
const { enhance } = await loadSchema(
341383
`
342384
model User {

0 commit comments

Comments
 (0)