1
1
/* eslint-disable @typescript-eslint/no-explicit-any */
2
2
3
3
import { PrismaClientKnownRequestError , PrismaClientUnknownRequestError } from '@prisma/client/runtime' ;
4
- import { AUXILIARY_FIELDS , CrudFailureReason , TRANSACTION_FIELD_NAME } from '@zenstackhq/sdk' ;
4
+ import { AUXILIARY_FIELDS , CrudFailureReason , GUARD_FIELD_NAME , TRANSACTION_FIELD_NAME } from '@zenstackhq/sdk' ;
5
5
import { camelCase } from 'change-case' ;
6
6
import cuid from 'cuid' ;
7
7
import deepcopy from 'deepcopy' ;
@@ -42,8 +42,7 @@ export class PolicyUtil {
42
42
and ( ...conditions : ( boolean | object ) [ ] ) : any {
43
43
if ( conditions . includes ( false ) ) {
44
44
// always false
45
- // TODO: custom id field
46
- return { id : { in : [ ] } } ;
45
+ return { [ GUARD_FIELD_NAME ] : false } ;
47
46
}
48
47
49
48
const filtered = conditions . filter (
@@ -64,7 +63,7 @@ export class PolicyUtil {
64
63
or ( ...conditions : ( boolean | object ) [ ] ) : any {
65
64
if ( conditions . includes ( true ) ) {
66
65
// always true
67
- return { id : { notIn : [ ] } } ;
66
+ return { [ GUARD_FIELD_NAME ] : true } ;
68
67
}
69
68
70
69
const filtered = conditions . filter ( ( c ) : c is object => typeof c === 'object' && ! ! c ) ;
@@ -276,7 +275,7 @@ export class PolicyUtil {
276
275
return ;
277
276
}
278
277
279
- const idField = this . getIdField ( model ) ;
278
+ const idFields = this . getIdFields ( model ) ;
280
279
for ( const field of getModelFields ( injectTarget ) ) {
281
280
const fieldInfo = resolveField ( this . modelMeta , model , field ) ;
282
281
if ( ! fieldInfo || ! fieldInfo . isDataModel ) {
@@ -292,10 +291,16 @@ export class PolicyUtil {
292
291
293
292
await this . injectAuthGuard ( injectTarget [ field ] , fieldInfo . type , 'read' ) ;
294
293
} else {
295
- // there's no way of injecting condition for to-one relation, so we
296
- // make sure 'id' field is selected and check them against query result
297
- if ( injectTarget [ field ] ?. select && injectTarget [ field ] ?. select ?. [ idField . name ] !== true ) {
298
- injectTarget [ field ] . select [ idField . name ] = true ;
294
+ // there's no way of injecting condition for to-one relation, so if there's
295
+ // "select" clause we make sure 'id' fields are selected and check them against
296
+ // query result; nothing needs to be done for "include" clause because all
297
+ // fields are already selected
298
+ if ( injectTarget [ field ] ?. select ) {
299
+ for ( const idField of idFields ) {
300
+ if ( injectTarget [ field ] . select [ idField . name ] !== true ) {
301
+ injectTarget [ field ] . select [ idField . name ] = true ;
302
+ }
303
+ }
299
304
}
300
305
}
301
306
@@ -310,7 +315,8 @@ export class PolicyUtil {
310
315
* omitted.
311
316
*/
312
317
async postProcessForRead ( entityData : any , model : string , args : any , operation : PolicyOperationKind ) {
313
- if ( ! this . getEntityId ( model , entityData ) ) {
318
+ const ids = this . getEntityIds ( model , entityData ) ;
319
+ if ( Object . keys ( ids ) . length === 0 ) {
314
320
return ;
315
321
}
316
322
@@ -330,21 +336,23 @@ export class PolicyUtil {
330
336
// post-check them
331
337
332
338
for ( const field of getModelFields ( injectTarget ) ) {
339
+ if ( ! entityData ?. [ field ] ) {
340
+ continue ;
341
+ }
342
+
333
343
const fieldInfo = resolveField ( this . modelMeta , model , field ) ;
334
344
if ( ! fieldInfo || ! fieldInfo . isDataModel || fieldInfo . isArray ) {
335
345
continue ;
336
346
}
337
347
338
- const idField = this . getIdField ( fieldInfo . type ) ;
339
- const relatedEntityId = entityData ?. [ field ] ?. [ idField . name ] ;
348
+ const ids = this . getEntityIds ( fieldInfo . type , entityData [ field ] ) ;
340
349
341
- if ( ! relatedEntityId ) {
350
+ if ( Object . keys ( ids ) . length === 0 ) {
342
351
continue ;
343
352
}
344
353
345
- this . logger . info ( `Validating read of to-one relation: ${ fieldInfo . type } #${ relatedEntityId } ` ) ;
346
-
347
- await this . checkPolicyForFilter ( fieldInfo . type , { [ idField . name ] : relatedEntityId } , operation , this . db ) ;
354
+ this . logger . info ( `Validating read of to-one relation: ${ fieldInfo . type } #${ formatObject ( ids ) } ` ) ;
355
+ await this . checkPolicyForFilter ( fieldInfo . type , ids , operation , this . db ) ;
348
356
349
357
// recurse
350
358
await this . postProcessForRead ( entityData [ field ] , fieldInfo . type , injectTarget [ field ] , operation ) ;
@@ -366,14 +374,18 @@ export class PolicyUtil {
366
374
367
375
// record model entities that are updated, together with their
368
376
// values before update, so we can post-check if they satisfy
369
- // model => id => entity value
370
- const updatedModels = new Map < string , Map < string , any > > ( ) ;
377
+ // model => { ids, entity value }
378
+ const updatedModels = new Map < string , Array < { ids : Record < string , unknown > ; value : any } > > ( ) ;
371
379
372
- const idField = this . getIdField ( model ) ;
373
- if ( args . select && ! args . select [ idField . name ] ) {
380
+ const idFields = this . getIdFields ( model ) ;
381
+ if ( args . select ) {
374
382
// make sure 'id' field is selected, we need it to
375
383
// read back the updated entity
376
- args . select [ idField . name ] = true ;
384
+ for ( const idField of idFields ) {
385
+ if ( ! args . select [ idField . name ] ) {
386
+ args . select [ idField . name ] = true ;
387
+ }
388
+ }
377
389
}
378
390
379
391
// use a transaction to conduct write, so in case any create or nested create
@@ -496,7 +508,7 @@ export class PolicyUtil {
496
508
if ( postGuard !== true || schema ) {
497
509
let modelEntities = updatedModels . get ( model ) ;
498
510
if ( ! modelEntities ) {
499
- modelEntities = new Map < string , any > ( ) ;
511
+ modelEntities = [ ] ;
500
512
updatedModels . set ( model , modelEntities ) ;
501
513
}
502
514
@@ -509,11 +521,19 @@ export class PolicyUtil {
509
521
// e.g.: { a_b: { a: '1', b: '1' } } => { a: '1', b: '1' }
510
522
await this . flattenGeneratedUniqueField ( model , filter ) ;
511
523
512
- const idField = this . getIdField ( model ) ;
513
- const query = { where : filter , select : { ...preValueSelect , [ idField . name ] : true } } ;
524
+ const idFields = this . getIdFields ( model ) ;
525
+ // eslint-disable-next-line @typescript-eslint/no-unused-vars
526
+ const select : any = { ...preValueSelect } ;
527
+ for ( const idField of idFields ) {
528
+ select [ idField . name ] = true ;
529
+ }
530
+
531
+ const query = { where : filter , select } ;
514
532
this . logger . info ( `fetching pre-update entities for ${ model } : ${ formatObject ( query ) } )}` ) ;
515
533
const entities = await this . db [ model ] . findMany ( query ) ;
516
- entities . forEach ( ( entity ) => modelEntities ?. set ( this . getEntityId ( model , entity ) , entity ) ) ;
534
+ entities . forEach ( ( entity ) =>
535
+ modelEntities ?. push ( { ids : this . getEntityIds ( model , entity ) , value : entity } )
536
+ ) ;
517
537
}
518
538
} ;
519
539
@@ -622,8 +642,8 @@ export class PolicyUtil {
622
642
await Promise . all (
623
643
[ ...updatedModels . entries ( ) ]
624
644
. map ( ( [ model , modelEntities ] ) =>
625
- [ ... modelEntities . entries ( ) ] . map ( async ( [ id , preValue ] ) =>
626
- this . checkPostUpdate ( model , id , tx , preValue )
645
+ modelEntities . map ( async ( { ids , value : preValue } ) =>
646
+ this . checkPostUpdate ( model , ids , tx , preValue )
627
647
)
628
648
)
629
649
. flat ( )
@@ -716,14 +736,18 @@ export class PolicyUtil {
716
736
}
717
737
}
718
738
719
- private async checkPostUpdate ( model : string , id : any , db : Record < string , DbOperations > , preValue : any ) {
720
- this . logger . info ( `Checking post-update policy for ${ model } #${ id } , preValue: ${ formatObject ( preValue ) } ` ) ;
739
+ private async checkPostUpdate (
740
+ model : string ,
741
+ ids : Record < string , unknown > ,
742
+ db : Record < string , DbOperations > ,
743
+ preValue : any
744
+ ) {
745
+ this . logger . info ( `Checking post-update policy for ${ model } #${ ids } , preValue: ${ formatObject ( preValue ) } ` ) ;
721
746
722
747
const guard = await this . getAuthGuard ( model , 'postUpdate' , preValue ) ;
723
748
724
749
// build a query condition with policy injected
725
- const idField = this . getIdField ( model ) ;
726
- const guardedQuery = { where : this . and ( { [ idField . name ] : id } , guard ) } ;
750
+ const guardedQuery = { where : this . and ( ids , guard ) } ;
727
751
728
752
// query with policy injected
729
753
const entity = await db [ model ] . findFirst ( guardedQuery ) ;
@@ -760,13 +784,13 @@ export class PolicyUtil {
760
784
/**
761
785
* Gets "id" field for a given model.
762
786
*/
763
- getIdField ( model : string ) {
787
+ getIdFields ( model : string ) {
764
788
const fields = this . modelMeta . fields [ camelCase ( model ) ] ;
765
789
if ( ! fields ) {
766
790
throw this . unknownError ( `Unable to load fields for ${ model } ` ) ;
767
791
}
768
- const result = Object . values ( fields ) . find ( ( f ) => f . isId ) ;
769
- if ( ! result ) {
792
+ const result = Object . values ( fields ) . filter ( ( f ) => f . isId ) ;
793
+ if ( result . length === 0 ) {
770
794
throw this . unknownError ( `model ${ model } does not have an id field` ) ;
771
795
}
772
796
return result ;
@@ -775,8 +799,12 @@ export class PolicyUtil {
775
799
/**
776
800
* Gets id field value from an entity.
777
801
*/
778
- getEntityId ( model : string , entityData : any ) {
779
- const idField = this . getIdField ( model ) ;
780
- return entityData [ idField . name ] ;
802
+ getEntityIds ( model : string , entityData : any ) {
803
+ const idFields = this . getIdFields ( model ) ;
804
+ const result : Record < string , unknown > = { } ;
805
+ for ( const idField of idFields ) {
806
+ result [ idField . name ] = entityData [ idField . name ] ;
807
+ }
808
+ return result ;
781
809
}
782
810
}
0 commit comments