Skip to content

Commit 0aa6ee9

Browse files
authored
feat: implementing permission checker (#1411)
1 parent fe85134 commit 0aa6ee9

File tree

18 files changed

+1846
-22
lines changed

18 files changed

+1846
-22
lines changed

packages/runtime/package.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@
2929
"types": "./enhancements/index.d.ts",
3030
"default": "./enhancements/index.js"
3131
},
32+
"./constraint-solver": {
33+
"types": "./constraint-solver.d.ts",
34+
"default": "./constraint-solver.js"
35+
},
3236
"./zod": {
3337
"types": "./zod/index.d.ts",
3438
"default": "./zod/index.js"
@@ -79,12 +83,14 @@
7983
"decimal.js": "^10.4.2",
8084
"deepcopy": "^2.1.0",
8185
"deepmerge": "^4.3.1",
86+
"logic-solver": "^2.0.1",
8287
"lower-case-first": "^2.0.2",
8388
"pluralize": "^8.0.0",
8489
"safe-json-stringify": "^1.2.0",
8590
"semver": "^7.5.2",
8691
"superjson": "^1.11.0",
8792
"tiny-invariant": "^1.3.1",
93+
"ts-pattern": "^4.3.0",
8894
"tslib": "^2.4.1",
8995
"upper-case-first": "^2.0.2",
9096
"uuid": "^9.0.0",
Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
import Logic from 'logic-solver';
2+
import { match } from 'ts-pattern';
3+
import type {
4+
CheckerConstraint,
5+
ComparisonConstraint,
6+
ComparisonTerm,
7+
LogicalConstraint,
8+
ValueConstraint,
9+
VariableConstraint,
10+
} from '../types';
11+
12+
/**
13+
* A boolean constraint solver based on `logic-solver`. Only boolean and integer types are supported.
14+
*/
15+
export class ConstraintSolver {
16+
// a table for internalizing string literals
17+
private stringTable: string[] = [];
18+
19+
// a map for storing variable names and their corresponding formulas
20+
private variables: Map<string, Logic.Formula> = new Map<string, Logic.Formula>();
21+
22+
/**
23+
* Check the satisfiability of the given constraint.
24+
*/
25+
checkSat(constraint: CheckerConstraint): boolean {
26+
// reset state
27+
this.stringTable = [];
28+
this.variables = new Map<string, Logic.Formula>();
29+
30+
// convert the constraint to a "logic-solver" formula
31+
const formula = this.buildFormula(constraint);
32+
33+
// solve the formula
34+
const solver = new Logic.Solver();
35+
solver.require(formula);
36+
37+
// DEBUG:
38+
// const solution = solver.solve();
39+
// if (solution) {
40+
// console.log('Solution:');
41+
// this.variables.forEach((v, k) => console.log(`\t${k}=${solution?.evaluate(v)}`));
42+
// } else {
43+
// console.log('No solution');
44+
// }
45+
46+
return !!solver.solve();
47+
}
48+
49+
private buildFormula(constraint: CheckerConstraint): Logic.Formula {
50+
return match(constraint)
51+
.when(
52+
(c): c is ValueConstraint => c.kind === 'value',
53+
(c) => this.buildValueFormula(c)
54+
)
55+
.when(
56+
(c): c is VariableConstraint => c.kind === 'variable',
57+
(c) => this.buildVariableFormula(c)
58+
)
59+
.when(
60+
(c): c is ComparisonConstraint => ['eq', 'ne', 'gt', 'gte', 'lt', 'lte'].includes(c.kind),
61+
(c) => this.buildComparisonFormula(c)
62+
)
63+
.when(
64+
(c): c is LogicalConstraint => ['and', 'or', 'not'].includes(c.kind),
65+
(c) => this.buildLogicalFormula(c)
66+
)
67+
.otherwise(() => {
68+
throw new Error(`Unsupported constraint format: ${JSON.stringify(constraint)}`);
69+
});
70+
}
71+
72+
private buildLogicalFormula(constraint: LogicalConstraint) {
73+
return match(constraint.kind)
74+
.with('and', () => this.buildAndFormula(constraint))
75+
.with('or', () => this.buildOrFormula(constraint))
76+
.with('not', () => this.buildNotFormula(constraint))
77+
.exhaustive();
78+
}
79+
80+
private buildAndFormula(constraint: LogicalConstraint): Logic.Formula {
81+
if (constraint.children.some((c) => this.isFalse(c))) {
82+
// short-circuit
83+
return Logic.FALSE;
84+
}
85+
return Logic.and(...constraint.children.map((c) => this.buildFormula(c)));
86+
}
87+
88+
private buildOrFormula(constraint: LogicalConstraint): Logic.Formula {
89+
if (constraint.children.some((c) => this.isTrue(c))) {
90+
// short-circuit
91+
return Logic.TRUE;
92+
}
93+
return Logic.or(...constraint.children.map((c) => this.buildFormula(c)));
94+
}
95+
96+
private buildNotFormula(constraint: LogicalConstraint) {
97+
if (constraint.children.length !== 1) {
98+
throw new Error('"not" constraint must have exactly one child');
99+
}
100+
return Logic.not(this.buildFormula(constraint.children[0]));
101+
}
102+
103+
private isTrue(constraint: CheckerConstraint): unknown {
104+
return constraint.kind === 'value' && constraint.value === true;
105+
}
106+
107+
private isFalse(constraint: CheckerConstraint): unknown {
108+
return constraint.kind === 'value' && constraint.value === false;
109+
}
110+
111+
private buildComparisonFormula(constraint: ComparisonConstraint) {
112+
if (constraint.left.kind === 'value' && constraint.right.kind === 'value') {
113+
// constant comparison
114+
const left: ValueConstraint = constraint.left;
115+
const right: ValueConstraint = constraint.right;
116+
return match(constraint.kind)
117+
.with('eq', () => (left.value === right.value ? Logic.TRUE : Logic.FALSE))
118+
.with('ne', () => (left.value !== right.value ? Logic.TRUE : Logic.FALSE))
119+
.with('gt', () => (left.value > right.value ? Logic.TRUE : Logic.FALSE))
120+
.with('gte', () => (left.value >= right.value ? Logic.TRUE : Logic.FALSE))
121+
.with('lt', () => (left.value < right.value ? Logic.TRUE : Logic.FALSE))
122+
.with('lte', () => (left.value <= right.value ? Logic.TRUE : Logic.FALSE))
123+
.exhaustive();
124+
}
125+
126+
return match(constraint.kind)
127+
.with('eq', () => this.transformEquality(constraint.left, constraint.right))
128+
.with('ne', () => this.transformInequality(constraint.left, constraint.right))
129+
.with('gt', () =>
130+
this.transformComparison(constraint.left, constraint.right, (l, r) => Logic.greaterThan(l, r))
131+
)
132+
.with('gte', () =>
133+
this.transformComparison(constraint.left, constraint.right, (l, r) => Logic.greaterThanOrEqual(l, r))
134+
)
135+
.with('lt', () =>
136+
this.transformComparison(constraint.left, constraint.right, (l, r) => Logic.lessThan(l, r))
137+
)
138+
.with('lte', () =>
139+
this.transformComparison(constraint.left, constraint.right, (l, r) => Logic.lessThanOrEqual(l, r))
140+
)
141+
.exhaustive();
142+
}
143+
144+
private buildVariableFormula(constraint: VariableConstraint) {
145+
return (
146+
match(constraint.type)
147+
.with('boolean', () => this.booleanVariable(constraint.name))
148+
.with('number', () => this.intVariable(constraint.name))
149+
// strings are internalized and represented by their indices
150+
.with('string', () => this.intVariable(constraint.name))
151+
.exhaustive()
152+
);
153+
}
154+
155+
private buildValueFormula(constraint: ValueConstraint) {
156+
return match(constraint.value)
157+
.when(
158+
(v): v is boolean => typeof v === 'boolean',
159+
(v) => (v === true ? Logic.TRUE : Logic.FALSE)
160+
)
161+
.when(
162+
(v): v is number => typeof v === 'number',
163+
(v) => Logic.constantBits(v)
164+
)
165+
.when(
166+
(v): v is string => typeof v === 'string',
167+
(v) => {
168+
// internalize the string and use its index as formula representation
169+
const index = this.stringTable.indexOf(v);
170+
if (index === -1) {
171+
this.stringTable.push(v);
172+
return Logic.constantBits(this.stringTable.length - 1);
173+
} else {
174+
return Logic.constantBits(index);
175+
}
176+
}
177+
)
178+
.exhaustive();
179+
}
180+
181+
private booleanVariable(name: string) {
182+
this.variables.set(name, name);
183+
return name;
184+
}
185+
186+
private intVariable(name: string) {
187+
const r = Logic.variableBits(name, 32);
188+
this.variables.set(name, r);
189+
return r;
190+
}
191+
192+
private transformEquality(left: ComparisonTerm, right: ComparisonTerm) {
193+
if (left.type !== right.type) {
194+
throw new Error(`Type mismatch in equality constraint: ${JSON.stringify(left)}, ${JSON.stringify(right)}`);
195+
}
196+
197+
const leftFormula = this.buildFormula(left);
198+
const rightFormula = this.buildFormula(right);
199+
if (left.type === 'boolean' && right.type === 'boolean') {
200+
// logical equivalence
201+
return Logic.equiv(leftFormula, rightFormula);
202+
} else {
203+
// integer equality
204+
return Logic.equalBits(leftFormula, rightFormula);
205+
}
206+
}
207+
208+
private transformInequality(left: ComparisonTerm, right: ComparisonTerm) {
209+
return Logic.not(this.transformEquality(left, right));
210+
}
211+
212+
private transformComparison(
213+
left: ComparisonTerm,
214+
right: ComparisonTerm,
215+
func: (left: Logic.Formula, right: Logic.Formula) => Logic.Formula
216+
) {
217+
return func(this.buildFormula(left), this.buildFormula(right));
218+
}
219+
}

packages/runtime/src/enhancements/policy/handler.ts

Lines changed: 119 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import { lowerCaseFirst } from 'lower-case-first';
44
import invariant from 'tiny-invariant';
5+
import { P, match } from 'ts-pattern';
56
import { upperCaseFirst } from 'upper-case-first';
67
import { fromZodError } from 'zod-validation-error';
78
import { CrudFailureReason } from '../../constants';
@@ -16,13 +17,15 @@ import {
1617
type FieldInfo,
1718
type ModelMeta,
1819
} from '../../cross';
19-
import { PolicyOperationKind, type CrudContract, type DbClientContract } from '../../types';
20+
import { PolicyCrudKind, PolicyOperationKind, type CrudContract, type DbClientContract } from '../../types';
2021
import type { EnhancementContext, InternalEnhancementOptions } from '../create-enhancement';
2122
import { Logger } from '../logger';
2223
import { createDeferredPromise, createFluentPromise } from '../promise';
2324
import { PrismaProxyHandler } from '../proxy';
2425
import { QueryUtils } from '../query-utils';
26+
import type { CheckerConstraint } from '../types';
2527
import { clone, formatObject, isUnsafeMutate, prismaClientValidationError } from '../utils';
28+
import { ConstraintSolver } from './constraint-solver';
2629
import { PolicyUtil } from './policy-utils';
2730

2831
// a record for post-write policy check
@@ -35,6 +38,12 @@ type PostWriteCheckRecord = {
3538

3639
type FindOperations = 'findUnique' | 'findUniqueOrThrow' | 'findFirst' | 'findFirstOrThrow' | 'findMany';
3740

41+
// input arg type for `check` API
42+
type PermissionCheckArgs = {
43+
operation: PolicyCrudKind;
44+
filter?: Record<string, number | string | boolean>;
45+
};
46+
3847
/**
3948
* Prisma proxy handler for injecting access policy check.
4049
*/
@@ -1436,6 +1445,115 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
14361445

14371446
//#endregion
14381447

1448+
//#region Check
1449+
1450+
/**
1451+
* Checks if the given operation is possibly allowed by the policy, without querying the database.
1452+
* @param operation The CRUD operation.
1453+
* @param fieldValues Extra field value filters to be combined with the policy constraints.
1454+
*/
1455+
async check(args: PermissionCheckArgs): Promise<boolean> {
1456+
return createDeferredPromise(() => this.doCheck(args));
1457+
}
1458+
1459+
private async doCheck(args: PermissionCheckArgs) {
1460+
if (!['create', 'read', 'update', 'delete'].includes(args.operation)) {
1461+
throw prismaClientValidationError(this.prisma, this.prismaModule, `Invalid "operation" ${args.operation}`);
1462+
}
1463+
1464+
let constraint = this.policyUtils.getCheckerConstraint(this.model, args.operation);
1465+
if (typeof constraint === 'boolean') {
1466+
return constraint;
1467+
}
1468+
1469+
if (args.filter) {
1470+
// combine runtime filters with generated constraints
1471+
1472+
const extraConstraints: CheckerConstraint[] = [];
1473+
for (const [field, value] of Object.entries(args.filter)) {
1474+
if (value === undefined) {
1475+
continue;
1476+
}
1477+
1478+
if (value === null) {
1479+
throw prismaClientValidationError(
1480+
this.prisma,
1481+
this.prismaModule,
1482+
`Using "null" as filter value is not supported yet`
1483+
);
1484+
}
1485+
1486+
const fieldInfo = requireField(this.modelMeta, this.model, field);
1487+
1488+
// relation and array fields are not supported
1489+
if (fieldInfo.isDataModel || fieldInfo.isArray) {
1490+
throw prismaClientValidationError(
1491+
this.prisma,
1492+
this.prismaModule,
1493+
`Providing filter for field "${field}" is not supported. Only scalar fields are allowed.`
1494+
);
1495+
}
1496+
1497+
// map field type to constraint type
1498+
const fieldType = match<string, 'number' | 'string' | 'boolean'>(fieldInfo.type)
1499+
.with(P.union('Int', 'BigInt', 'Float', 'Decimal'), () => 'number')
1500+
.with('String', () => 'string')
1501+
.with('Boolean', () => 'boolean')
1502+
.otherwise(() => {
1503+
throw prismaClientValidationError(
1504+
this.prisma,
1505+
this.prismaModule,
1506+
`Providing filter for field "${field}" is not supported. Only number, string, and boolean fields are allowed.`
1507+
);
1508+
});
1509+
1510+
// check value type
1511+
const valueType = typeof value;
1512+
if (valueType !== 'number' && valueType !== 'string' && valueType !== 'boolean') {
1513+
throw prismaClientValidationError(
1514+
this.prisma,
1515+
this.prismaModule,
1516+
`Invalid value type for field "${field}". Only number, string or boolean is allowed.`
1517+
);
1518+
}
1519+
1520+
if (fieldType !== valueType) {
1521+
throw prismaClientValidationError(
1522+
this.prisma,
1523+
this.prismaModule,
1524+
`Invalid value type for field "${field}". Expected "${fieldType}".`
1525+
);
1526+
}
1527+
1528+
// check number validity
1529+
if (typeof value === 'number' && (!Number.isInteger(value) || value < 0)) {
1530+
throw prismaClientValidationError(
1531+
this.prisma,
1532+
this.prismaModule,
1533+
`Invalid value for field "${field}". Only non-negative integers are allowed.`
1534+
);
1535+
}
1536+
1537+
// build a constraint
1538+
extraConstraints.push({
1539+
kind: 'eq',
1540+
left: { kind: 'variable', name: field, type: fieldType },
1541+
right: { kind: 'value', value, type: fieldType },
1542+
});
1543+
}
1544+
1545+
if (extraConstraints.length > 0) {
1546+
// combine the constraints
1547+
constraint = { kind: 'and', children: [constraint, ...extraConstraints] };
1548+
}
1549+
}
1550+
1551+
// check satisfiability
1552+
return new ConstraintSolver().checkSat(constraint);
1553+
}
1554+
1555+
//#endregion
1556+
14391557
//#region Utils
14401558

14411559
private get shouldLogQuery() {

0 commit comments

Comments
 (0)