@@ -420,7 +420,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
420420        } ) ; 
421421
422422        // return only the ids of the top-level entity 
423-         const  ids  =  this . utils . getEntityIds ( this . model ,  result ) ; 
423+         const  ids  =  this . utils . getEntityIds ( model ,  result ) ; 
424424        return  {  result : ids ,  postWriteChecks : [ ...postCreateChecks . values ( ) ]  } ; 
425425    } 
426426
@@ -792,8 +792,10 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
792792            } 
793793
794794            // proceed with the create and collect post-create checks 
795-             const  {  postWriteChecks : checks  }  =  await  this . doCreate ( model ,  {  data : createData  } ,  db ) ; 
795+             const  {  postWriteChecks : checks ,  result  }  =  await  this . doCreate ( model ,  {  data : createData  } ,  db ) ; 
796796            postWriteChecks . push ( ...checks ) ; 
797+ 
798+             return  result ; 
797799        } ; 
798800
799801        const  _createMany  =  async  ( 
@@ -881,18 +883,10 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
881883                    // check pre-update guard 
882884                    await  this . utils . checkPolicyForUnique ( model ,  uniqueFilter ,  'update' ,  db ,  args ) ; 
883885
884-                     // handles the case where id fields are updated 
885-                     const  postUpdateIds  =  this . utils . clone ( existing ) ; 
886-                     for  ( const  key  of  Object . keys ( existing ) )  { 
887-                         const  updateValue  =  ( args  as  any ) . data  ? ( args  as  any ) . data [ key ]  : ( args  as  any ) [ key ] ; 
888-                         if  ( 
889-                             typeof  updateValue  ===  'string'  || 
890-                             typeof  updateValue  ===  'number'  || 
891-                             typeof  updateValue  ===  'bigint' 
892-                         )  { 
893-                             postUpdateIds [ key ]  =  updateValue ; 
894-                         } 
895-                     } 
886+                     // handle the case where id fields are updated 
887+                     const  _args : any  =  args ; 
888+                     const  updatePayload  =  _args . data  &&  typeof  _args . data  ===  'object'  ? _args . data  : _args ; 
889+                     const  postUpdateIds  =  this . calculatePostUpdateIds ( model ,  existing ,  updatePayload ) ; 
896890
897891                    // register post-update check 
898892                    await  _registerPostUpdateCheck ( model ,  existing ,  postUpdateIds ) ; 
@@ -984,10 +978,13 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
984978                    // update case 
985979
986980                    // check pre-update guard 
987-                     await  this . utils . checkPolicyForUnique ( model ,  uniqueFilter ,  'update' ,  db ,  args ) ; 
981+                     await  this . utils . checkPolicyForUnique ( model ,  existing ,  'update' ,  db ,  args ) ; 
982+ 
983+                     // handle the case where id fields are updated 
984+                     const  postUpdateIds  =  this . calculatePostUpdateIds ( model ,  existing ,  args . update ) ; 
988985
989986                    // register post-update check 
990-                     await  _registerPostUpdateCheck ( model ,  uniqueFilter ,   uniqueFilter ) ; 
987+                     await  _registerPostUpdateCheck ( model ,  existing ,   postUpdateIds ) ; 
991988
992989                    // convert upsert to update 
993990                    const  convertedUpdate  =  { 
@@ -1021,9 +1018,22 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
10211018                if  ( existing )  { 
10221019                    // connect 
10231020                    await  _connectDisconnect ( model ,  args . where ,  context ) ; 
1021+                     return  true ; 
10241022                }  else  { 
10251023                    // create 
1026-                     await  _create ( model ,  args . create ,  context ) ; 
1024+                     const  created  =  await  _create ( model ,  args . create ,  context ) ; 
1025+ 
1026+                     const  upperContext  =  context . nestingPath [ context . nestingPath . length  -  2 ] ; 
1027+                     if  ( upperContext ?. where  &&  context . field )  { 
1028+                         // check if the where clause of the upper context references the id 
1029+                         // of the connected entity, if so, we need to update it 
1030+                         this . overrideForeignKeyFields ( upperContext . model ,  upperContext . where ,  context . field ,  created ) ; 
1031+                     } 
1032+ 
1033+                     // remove the payload from the parent 
1034+                     this . removeFromParent ( context . parent ,  'connectOrCreate' ,  args ) ; 
1035+ 
1036+                     return  false ; 
10271037                } 
10281038            } , 
10291039
@@ -1093,6 +1103,52 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
10931103        return  {  result,  postWriteChecks } ; 
10941104    } 
10951105
1106+     // calculate id fields used for post-update check given an update payload 
1107+     private  calculatePostUpdateIds ( _model : string ,  currentIds : any ,  updatePayload : any )  { 
1108+         const  result  =  this . utils . clone ( currentIds ) ; 
1109+         for  ( const  key  of  Object . keys ( currentIds ) )  { 
1110+             const  updateValue  =  updatePayload [ key ] ; 
1111+             if  ( typeof  updateValue  ===  'string'  ||  typeof  updateValue  ===  'number'  ||  typeof  updateValue  ===  'bigint' )  { 
1112+                 result [ key ]  =  updateValue ; 
1113+             } 
1114+         } 
1115+         return  result ; 
1116+     } 
1117+ 
1118+     // updates foreign key fields inside `payload` based on relation id fields in `newIds` 
1119+     private  overrideForeignKeyFields ( 
1120+         model : string , 
1121+         payload : any , 
1122+         relation : FieldInfo , 
1123+         newIds : Record < string ,  unknown > 
1124+     )  { 
1125+         if  ( ! relation . foreignKeyMapping  ||  Object . keys ( relation . foreignKeyMapping ) . length  ===  0 )  { 
1126+             return ; 
1127+         } 
1128+ 
1129+         // override foreign key values 
1130+         for  ( const  [ id ,  fk ]  of  Object . entries ( relation . foreignKeyMapping ) )  { 
1131+             if  ( payload [ fk ]  !==  undefined  &&  newIds [ id ]  !==  undefined )  { 
1132+                 payload [ fk ]  =  newIds [ id ] ; 
1133+             } 
1134+         } 
1135+ 
1136+         // deal with compound id fields 
1137+         const  uniqueConstraints  =  this . utils . getUniqueConstraints ( model ) ; 
1138+         for  ( const  [ name ,  constraint ]  of  Object . entries ( uniqueConstraints ) )  { 
1139+             if  ( constraint . fields . length  >  1 )  { 
1140+                 const  target  =  payload [ name ] ; 
1141+                 if  ( target )  { 
1142+                     for  ( const  [ id ,  fk ]  of  Object . entries ( relation . foreignKeyMapping ) )  { 
1143+                         if  ( target [ fk ]  !==  undefined  &&  newIds [ id ]  !==  undefined )  { 
1144+                             target [ fk ]  =  newIds [ id ] ; 
1145+                         } 
1146+                     } 
1147+                 } 
1148+             } 
1149+         } 
1150+     } 
1151+ 
10961152    // Validates the given update payload against Zod schema if any 
10971153    private  validateUpdateInputSchema ( model : string ,  data : any )  { 
10981154        const  schema  =  this . utils . getZodSchema ( model ,  'update' ) ; 
@@ -1224,11 +1280,18 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
12241280
12251281        const  {  result,  error }  =  await  this . transaction ( async  ( tx )  =>  { 
12261282            const  {  where,  create,  update,  ...rest  }  =  args ; 
1227-             const  existing  =  await  this . utils . checkExistence ( tx ,  this . model ,  args . where ) ; 
1283+             const  existing  =  await  this . utils . checkExistence ( tx ,  this . model ,  where ) ; 
12281284
12291285            if  ( existing )  { 
12301286                // update case 
1231-                 const  {  result,  postWriteChecks }  =  await  this . doUpdate ( {  where,  data : update ,  ...rest  } ,  tx ) ; 
1287+                 const  {  result,  postWriteChecks }  =  await  this . doUpdate ( 
1288+                     { 
1289+                         where : this . utils . composeCompoundUniqueField ( this . model ,  existing ) , 
1290+                         data : update , 
1291+                         ...rest , 
1292+                     } , 
1293+                     tx 
1294+                 ) ; 
12321295                await  this . runPostWriteChecks ( postWriteChecks ,  tx ) ; 
12331296                return  this . utils . readBack ( tx ,  this . model ,  'update' ,  args ,  result ) ; 
12341297            }  else  { 
0 commit comments