11/* eslint-disable @typescript-eslint/no-explicit-any */
22
3+ import deepmerge from 'deepmerge' ;
34import { lowerCaseFirst } from 'lower-case-first' ;
45import invariant from 'tiny-invariant' ;
56import { P , match } from 'ts-pattern' ;
@@ -23,7 +24,7 @@ import { Logger } from '../logger';
2324import { createDeferredPromise , createFluentPromise } from '../promise' ;
2425import { PrismaProxyHandler } from '../proxy' ;
2526import { QueryUtils } from '../query-utils' ;
26- import type { CheckerConstraint } from '../types' ;
27+ import type { AdditionalCheckerFunc , CheckerConstraint } from '../types' ;
2728import { clone , formatObject , isUnsafeMutate , prismaClientValidationError } from '../utils' ;
2829import { ConstraintSolver } from './constraint-solver' ;
2930import { PolicyUtil } from './policy-utils' ;
@@ -152,8 +153,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
152153 }
153154
154155 const result = await this . modelClient [ actionName ] ( _args ) ;
155- this . policyUtils . postProcessForRead ( result , this . model , origArgs ) ;
156- return result ;
156+ return this . policyUtils . postProcessForRead ( result , this . model , origArgs ) ;
157157 }
158158
159159 //#endregion
@@ -779,10 +779,27 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
779779 }
780780 } ;
781781
782- const _connectDisconnect = async ( model : string , args : any , context : NestedWriteVisitorContext ) => {
782+ const _connectDisconnect = async (
783+ model : string ,
784+ args : any ,
785+ context : NestedWriteVisitorContext ,
786+ operation : 'connect' | 'disconnect'
787+ ) => {
783788 if ( context . field ?. backLink ) {
784789 const backLinkField = this . policyUtils . getModelField ( model , context . field . backLink ) ;
785790 if ( backLinkField ?. isRelationOwner ) {
791+ let uniqueFilter = args ;
792+ if ( operation === 'disconnect' ) {
793+ // disconnect filter is not unique, need to build a reversed query to
794+ // locate the entity and use its id fields as unique filter
795+ const reversedQuery = this . policyUtils . buildReversedQuery ( context ) ;
796+ const found = await db [ model ] . findUnique ( {
797+ where : reversedQuery ,
798+ select : this . policyUtils . makeIdSelection ( model ) ,
799+ } ) ;
800+ uniqueFilter = found && this . policyUtils . getIdFieldValues ( model , found ) ;
801+ }
802+
786803 // update happens on the related model, require updatable,
787804 // translate args to foreign keys so field-level policies can be checked
788805 const checkArgs : any = { } ;
@@ -794,10 +811,15 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
794811 }
795812 }
796813 }
797- await this . policyUtils . checkPolicyForUnique ( model , args , 'update' , db , checkArgs ) ;
798814
799- // register post-update check
800- await _registerPostUpdateCheck ( model , args , args ) ;
815+ // `uniqueFilter` can be undefined if the entity to be disconnected doesn't exist
816+ if ( uniqueFilter ) {
817+ // check for update
818+ await this . policyUtils . checkPolicyForUnique ( model , uniqueFilter , 'update' , db , checkArgs ) ;
819+
820+ // register post-update check
821+ await _registerPostUpdateCheck ( model , uniqueFilter , uniqueFilter ) ;
822+ }
801823 }
802824 }
803825 } ;
@@ -970,14 +992,14 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
970992 }
971993 } ,
972994
973- connect : async ( model , args , context ) => _connectDisconnect ( model , args , context ) ,
995+ connect : async ( model , args , context ) => _connectDisconnect ( model , args , context , 'connect' ) ,
974996
975997 connectOrCreate : async ( model , args , context ) => {
976998 // the where condition is already unique, so we can use it to check if the target exists
977999 const existing = await this . policyUtils . checkExistence ( db , model , args . where ) ;
9781000 if ( existing ) {
9791001 // connect
980- await _connectDisconnect ( model , args . where , context ) ;
1002+ await _connectDisconnect ( model , args . where , context , 'connect' ) ;
9811003 return true ;
9821004 } else {
9831005 // create
@@ -997,7 +1019,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
9971019 }
9981020 } ,
9991021
1000- disconnect : async ( model , args , context ) => _connectDisconnect ( model , args , context ) ,
1022+ disconnect : async ( model , args , context ) => _connectDisconnect ( model , args , context , 'disconnect' ) ,
10011023
10021024 set : async ( model , args , context ) => {
10031025 // find the set of items to be replaced
@@ -1012,10 +1034,10 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
10121034 const currentSet = await db [ model ] . findMany ( findCurrSetArgs ) ;
10131035
10141036 // register current set for update (foreign key)
1015- await Promise . all ( currentSet . map ( ( item ) => _connectDisconnect ( model , item , context ) ) ) ;
1037+ await Promise . all ( currentSet . map ( ( item ) => _connectDisconnect ( model , item , context , 'disconnect' ) ) ) ;
10161038
10171039 // proceed with connecting the new set
1018- await Promise . all ( enumerate ( args ) . map ( ( item ) => _connectDisconnect ( model , item , context ) ) ) ;
1040+ await Promise . all ( enumerate ( args ) . map ( ( item ) => _connectDisconnect ( model , item , context , 'connect' ) ) ) ;
10191041 } ,
10201042
10211043 delete : async ( model , args , context ) => {
@@ -1160,48 +1182,78 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
11601182
11611183 args . data = this . validateUpdateInputSchema ( this . model , args . data ) ;
11621184
1163- if ( this . policyUtils . hasAuthGuard ( this . model , 'postUpdate' ) || this . policyUtils . getZodSchema ( this . model ) ) {
1164- // use a transaction to do post-update checks
1165- const postWriteChecks : PostWriteCheckRecord [ ] = [ ] ;
1166- return this . queryUtils . transaction ( this . prisma , async ( tx ) => {
1167- // collect pre-update values
1168- let select = this . policyUtils . makeIdSelection ( this . model ) ;
1169- const preValueSelect = this . policyUtils . getPreValueSelect ( this . model ) ;
1170- if ( preValueSelect ) {
1171- select = { ...select , ...preValueSelect } ;
1172- }
1173- const currentSetQuery = { select, where : args . where } ;
1174- this . policyUtils . injectAuthGuardAsWhere ( tx , currentSetQuery , this . model , 'read' ) ;
1185+ const additionalChecker = this . policyUtils . getAdditionalChecker ( this . model , 'update' ) ;
11751186
1176- if ( this . shouldLogQuery ) {
1177- this . logger . info ( `[policy] \`findMany\` ${ this . model } : ${ formatObject ( currentSetQuery ) } ` ) ;
1178- }
1179- const currentSet = await tx [ this . model ] . findMany ( currentSetQuery ) ;
1187+ const canProceedWithoutTransaction =
1188+ // no post-update rules
1189+ ! this . policyUtils . hasAuthGuard ( this . model , 'postUpdate' ) &&
1190+ // no Zod schema
1191+ ! this . policyUtils . getZodSchema ( this . model ) &&
1192+ // no additional checker
1193+ ! additionalChecker ;
11801194
1181- postWriteChecks . push (
1182- ...currentSet . map ( ( preValue ) => ( {
1183- model : this . model ,
1184- operation : 'postUpdate' as PolicyOperationKind ,
1185- uniqueFilter : this . policyUtils . getEntityIds ( this . model , preValue ) ,
1186- preValue : preValueSelect ? preValue : undefined ,
1187- } ) )
1188- ) ;
1189-
1190- // proceed with the update
1191- const result = await tx [ this . model ] . updateMany ( args ) ;
1192-
1193- // run post-write checks
1194- await this . runPostWriteChecks ( postWriteChecks , tx ) ;
1195-
1196- return result ;
1197- } ) ;
1198- } else {
1195+ if ( canProceedWithoutTransaction ) {
11991196 // proceed without a transaction
12001197 if ( this . shouldLogQuery ) {
12011198 this . logger . info ( `[policy] \`updateMany\` ${ this . model } : ${ formatObject ( args ) } ` ) ;
12021199 }
12031200 return this . modelClient . updateMany ( args ) ;
12041201 }
1202+
1203+ // collect post-update checks
1204+ const postWriteChecks : PostWriteCheckRecord [ ] = [ ] ;
1205+
1206+ return this . queryUtils . transaction ( this . prisma , async ( tx ) => {
1207+ // collect pre-update values
1208+ let select = this . policyUtils . makeIdSelection ( this . model ) ;
1209+ const preValueSelect = this . policyUtils . getPreValueSelect ( this . model ) ;
1210+ if ( preValueSelect ) {
1211+ select = { ...select , ...preValueSelect } ;
1212+ }
1213+
1214+ // merge selection required for running additional checker
1215+ const additionalCheckerSelector = this . policyUtils . getAdditionalCheckerSelector ( this . model , 'update' ) ;
1216+ if ( additionalCheckerSelector ) {
1217+ select = deepmerge ( select , additionalCheckerSelector ) ;
1218+ }
1219+
1220+ const currentSetQuery = { select, where : args . where } ;
1221+ this . policyUtils . injectAuthGuardAsWhere ( tx , currentSetQuery , this . model , 'update' ) ;
1222+
1223+ if ( this . shouldLogQuery ) {
1224+ this . logger . info ( `[policy] \`findMany\` ${ this . model } : ${ formatObject ( currentSetQuery ) } ` ) ;
1225+ }
1226+ let candidates = await tx [ this . model ] . findMany ( currentSetQuery ) ;
1227+
1228+ if ( additionalChecker ) {
1229+ // filter candidates with additional checker and build an id filter
1230+ const r = this . buildIdFilterWithAdditionalChecker ( candidates , additionalChecker ) ;
1231+ candidates = r . filteredCandidates ;
1232+
1233+ // merge id filter into update's where clause
1234+ args . where = args . where ? { AND : [ args . where , r . idFilter ] } : r . idFilter ;
1235+ }
1236+
1237+ postWriteChecks . push (
1238+ ...candidates . map ( ( preValue ) => ( {
1239+ model : this . model ,
1240+ operation : 'postUpdate' as PolicyOperationKind ,
1241+ uniqueFilter : this . policyUtils . getEntityIds ( this . model , preValue ) ,
1242+ preValue : preValueSelect ? preValue : undefined ,
1243+ } ) )
1244+ ) ;
1245+
1246+ // proceed with the update
1247+ if ( this . shouldLogQuery ) {
1248+ this . logger . info ( `[policy] \`updateMany\` in tx for ${ this . model } : ${ formatObject ( args ) } ` ) ;
1249+ }
1250+ const result = await tx [ this . model ] . updateMany ( args ) ;
1251+
1252+ // run post-write checks
1253+ await this . runPostWriteChecks ( postWriteChecks , tx ) ;
1254+
1255+ return result ;
1256+ } ) ;
12051257 } ) ;
12061258 }
12071259
@@ -1328,14 +1380,53 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
13281380 this . policyUtils . tryReject ( this . prisma , this . model , 'delete' ) ;
13291381
13301382 // inject policy conditions
1331- args = args ?? { } ;
1383+ args = clone ( args ) ;
13321384 this . policyUtils . injectAuthGuardAsWhere ( this . prisma , args , this . model , 'delete' ) ;
13331385
1334- // conduct the deletion
1335- if ( this . shouldLogQuery ) {
1336- this . logger . info ( `[policy] \`deleteMany\` ${ this . model } :\n${ formatObject ( args ) } ` ) ;
1386+ const additionalChecker = this . policyUtils . getAdditionalChecker ( this . model , 'delete' ) ;
1387+ if ( additionalChecker ) {
1388+ // additional checker exists, need to run deletion inside a transaction
1389+ return this . queryUtils . transaction ( this . prisma , async ( tx ) => {
1390+ // find the delete candidates, selecting id fields and fields needed for
1391+ // running the additional checker
1392+ let candidateSelect = this . policyUtils . makeIdSelection ( this . model ) ;
1393+ const additionalCheckerSelector = this . policyUtils . getAdditionalCheckerSelector (
1394+ this . model ,
1395+ 'delete'
1396+ ) ;
1397+ if ( additionalCheckerSelector ) {
1398+ candidateSelect = deepmerge ( candidateSelect , additionalCheckerSelector ) ;
1399+ }
1400+
1401+ if ( this . shouldLogQuery ) {
1402+ this . logger . info (
1403+ `[policy] \`findMany\` ${ this . model } : ${ formatObject ( {
1404+ where : args . where ,
1405+ select : candidateSelect ,
1406+ } ) } `
1407+ ) ;
1408+ }
1409+ const candidates = await tx [ this . model ] . findMany ( { where : args . where , select : candidateSelect } ) ;
1410+
1411+ // build a ID filter based on id values filtered by the additional checker
1412+ const { idFilter } = this . buildIdFilterWithAdditionalChecker ( candidates , additionalChecker ) ;
1413+
1414+ // merge the ID filter into the where clause
1415+ args . where = args . where ? { AND : [ args . where , idFilter ] } : idFilter ;
1416+
1417+ // finally, conduct the deletion with the combined where clause
1418+ if ( this . shouldLogQuery ) {
1419+ this . logger . info ( `[policy] \`deleteMany\` in tx for ${ this . model } :\n${ formatObject ( args ) } ` ) ;
1420+ }
1421+ return tx [ this . model ] . deleteMany ( args ) ;
1422+ } ) ;
1423+ } else {
1424+ // conduct the deletion directly
1425+ if ( this . shouldLogQuery ) {
1426+ this . logger . info ( `[policy] \`deleteMany\` ${ this . model } :\n${ formatObject ( args ) } ` ) ;
1427+ }
1428+ return this . modelClient . deleteMany ( args ) ;
13371429 }
1338- return this . modelClient . deleteMany ( args ) ;
13391430 } ) ;
13401431 }
13411432
@@ -1599,5 +1690,17 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
15991690 }
16001691 }
16011692
1693+ private buildIdFilterWithAdditionalChecker ( candidates : any [ ] , additionalChecker : AdditionalCheckerFunc ) {
1694+ const filteredCandidates = candidates . filter ( ( value ) => additionalChecker ( { user : this . context ?. user } , value ) ) ;
1695+ const idFields = this . policyUtils . getIdFields ( this . model ) ;
1696+ let idFilter : any ;
1697+ if ( idFields . length === 1 ) {
1698+ idFilter = { [ idFields [ 0 ] . name ] : { in : filteredCandidates . map ( ( x ) => x [ idFields [ 0 ] . name ] ) } } ;
1699+ } else {
1700+ idFilter = { AND : filteredCandidates . map ( ( x ) => this . policyUtils . getIdFieldValues ( this . model , x ) ) } ;
1701+ }
1702+ return { filteredCandidates, idFilter } ;
1703+ }
1704+
16021705 //#endregion
16031706}
0 commit comments