77    getDataModelAndTypeDefs , 
88    getDataModels , 
99    getLiteral , 
10+     getRelationField , 
1011    isDelegateModel , 
1112    isDiscriminatorField , 
1213    normalizedRelative , 
@@ -55,12 +56,23 @@ type DelegateInfo = [DataModel, DataModel[]][];
5556const  LOGICAL_CLIENT_GENERATION_PATH  =  './.logical-prisma-client' ; 
5657
5758export  class  EnhancerGenerator  { 
59+     // regex for matching "ModelCreateXXXInput" and "ModelUncheckedCreateXXXInput" type 
60+     // names for models that use `auth()` in `@default` attribute 
61+     private  readonly  modelsWithAuthInDefaultCreateInputPattern : RegExp ; 
62+ 
5863    constructor ( 
5964        private  readonly  model : Model , 
6065        private  readonly  options : PluginOptions , 
6166        private  readonly  project : Project , 
6267        private  readonly  outDir : string 
63-     )  { } 
68+     )  { 
69+         const  modelsWithAuthInDefault  =  this . model . declarations . filter ( 
70+             ( d ) : d  is DataModel  =>  isDataModel ( d )  &&  d . fields . some ( ( f )  =>  f . attributes . some ( isDefaultWithAuth ) ) 
71+         ) ; 
72+         this . modelsWithAuthInDefaultCreateInputPattern  =  new  RegExp ( 
73+             `^(${ modelsWithAuthInDefault . map ( ( m )  =>  m . name ) . join ( '|' ) }  
74+         ) ; 
75+     } 
6476
6577    async  generate ( ) : Promise < {  dmmf : DMMF . Document  |  undefined ;  newPrismaClientDtsPath : string  |  undefined  } >  { 
6678        let  dmmf : DMMF . Document  |  undefined ; 
@@ -69,7 +81,7 @@ export class EnhancerGenerator {
6981        let  prismaTypesFixed  =  false ; 
7082        let  resultPrismaImport  =  prismaImport ; 
7183
72-         if  ( this . needsLogicalClient   ||   this . needsPrismaClientTypeFixes )  { 
84+         if  ( this . needsLogicalClient )  { 
7385            prismaTypesFixed  =  true ; 
7486            resultPrismaImport  =  `${ LOGICAL_CLIENT_GENERATION_PATH }  ; 
7587            const  result  =  await  this . generateLogicalPrisma ( ) ; 
@@ -230,11 +242,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
230242    } 
231243
232244    private  get  needsLogicalClient ( )  { 
233-         return  this . hasDelegateModel ( this . model )  ||  this . hasAuthInDefault ( this . model ) ; 
234-     } 
235- 
236-     private  get  needsPrismaClientTypeFixes ( )  { 
237-         return  this . hasTypeDef ( this . model ) ; 
245+         return  this . hasDelegateModel ( this . model )  ||  this . hasAuthInDefault ( this . model )  ||  this . hasTypeDef ( this . model ) ; 
238246    } 
239247
240248    private  hasDelegateModel ( model : Model )  { 
@@ -449,11 +457,13 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
449457        const  auxFields  =  this . findAuxDecls ( variable ) ; 
450458        if  ( auxFields . length  >  0 )  { 
451459            structure . declarations . forEach ( ( variable )  =>  { 
452-                 let  source  =  variable . type ?. toString ( ) ; 
453-                 auxFields . forEach ( ( f )  =>  { 
454-                     source  =  source ?. replace ( f . getText ( ) ,  '' ) ; 
455-                 } ) ; 
456-                 variable . type  =  source ; 
460+                 if  ( variable . type )  { 
461+                     let  source  =  variable . type . toString ( ) ; 
462+                     auxFields . forEach ( ( f )  =>  { 
463+                         source  =  this . removeFromSource ( source ,  f . getText ( ) ) ; 
464+                     } ) ; 
465+                     variable . type  =  source ; 
466+                 } 
457467            } ) ; 
458468        } 
459469
@@ -498,72 +508,16 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
498508        // fix delegate payload union type 
499509        source  =  this . fixDelegatePayloadType ( typeAlias ,  delegateInfo ,  source ) ; 
500510
511+         // fix fk and relation fields related to using `auth()` in `@default` 
512+         source  =  this . fixDefaultAuthType ( typeAlias ,  source ) ; 
513+ 
501514        // fix json field type 
502515        source  =  this . fixJsonFieldType ( typeAlias ,  source ) ; 
503516
504517        structure . type  =  source ; 
505518        return  structure ; 
506519    } 
507520
508-     private  fixJsonFieldType ( typeAlias : TypeAliasDeclaration ,  source : string )  { 
509-         const  modelsWithTypeField  =  this . model . declarations . filter ( 
510-             ( d ) : d  is DataModel  =>  isDataModel ( d )  &&  d . fields . some ( ( f )  =>  isTypeDef ( f . type . reference ?. ref ) ) 
511-         ) ; 
512-         const  typeName  =  typeAlias . getName ( ) ; 
513- 
514-         const  getTypedJsonFields  =  ( model : DataModel )  =>  { 
515-             return  model . fields . filter ( ( f )  =>  isTypeDef ( f . type . reference ?. ref ) ) ; 
516-         } ; 
517- 
518-         const  replacePrismaJson  =  ( source : string ,  field : DataModelField )  =>  { 
519-             return  source . replace ( 
520-                 new  RegExp ( `(${ field . name }  ) , 
521-                 `$1: ${ field . type . reference ! . $refText } ${ field . type . array  ? '[]'  : '' } ${  
522-                     field . type . optional  ? ' | null'  : ''  
523-                 }  `
524-             ) ; 
525-         } ; 
526- 
527-         // fix "$[Model]Payload" type 
528-         const  payloadModelMatch  =  modelsWithTypeField . find ( ( m )  =>  `$${ m . name }   ===  typeName ) ; 
529-         if  ( payloadModelMatch )  { 
530-             const  scalars  =  typeAlias 
531-                 . getDescendantsOfKind ( SyntaxKind . PropertySignature ) 
532-                 . find ( ( p )  =>  p . getName ( )  ===  'scalars' ) ; 
533-             if  ( ! scalars )  { 
534-                 return  source ; 
535-             } 
536- 
537-             const  fieldsToFix  =  getTypedJsonFields ( payloadModelMatch ) ; 
538-             for  ( const  field  of  fieldsToFix )  { 
539-                 source  =  replacePrismaJson ( source ,  field ) ; 
540-             } 
541-         } 
542- 
543-         // fix input/output types, "[Model]CreateInput", etc. 
544-         const  inputOutputModelMatch  =  modelsWithTypeField . find ( ( m )  =>  typeName . startsWith ( m . name ) ) ; 
545-         if  ( inputOutputModelMatch )  { 
546-             const  relevantTypePatterns  =  [ 
547-                 'GroupByOutputType' , 
548-                 '(Unchecked)?Create(\\S+?)?Input' , 
549-                 '(Unchecked)?Update(\\S+?)?Input' , 
550-                 'CreateManyInput' , 
551-                 '(Unchecked)?UpdateMany(Mutation)?Input' , 
552-             ] ; 
553-             const  typeRegex  =  modelsWithTypeField . map ( 
554-                 ( m )  =>  new  RegExp ( `^(${ m . name } ${ relevantTypePatterns . join ( '|' ) }  ) 
555-             ) ; 
556-             if  ( typeRegex . some ( ( r )  =>  r . test ( typeName ) ) )  { 
557-                 const  fieldsToFix  =  getTypedJsonFields ( inputOutputModelMatch ) ; 
558-                 for  ( const  field  of  fieldsToFix )  { 
559-                     source  =  replacePrismaJson ( source ,  field ) ; 
560-                 } 
561-             } 
562-         } 
563- 
564-         return  source ; 
565-     } 
566- 
567521    private  fixDelegatePayloadType ( typeAlias : TypeAliasDeclaration ,  delegateInfo : DelegateInfo ,  source : string )  { 
568522        // change the type of `$<DelegateModel>Payload` type of delegate model to a union of concrete types 
569523        const  typeName  =  typeAlias . getName ( ) ; 
@@ -595,7 +549,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
595549                . getDescendantsOfKind ( SyntaxKind . PropertySignature ) 
596550                . filter ( ( p )  =>  [ 'create' ,  'createMany' ,  'connectOrCreate' ,  'upsert' ] . includes ( p . getName ( ) ) ) ; 
597551            toRemove . forEach ( ( r )  =>  { 
598-                 source   =   source . replace ( r . getText ( ) ,   '' ) ; 
552+                 this . removeFromSource ( source ,   r . getText ( ) ) ; 
599553            } ) ; 
600554        } 
601555        return  source ; 
@@ -633,7 +587,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
633587                if  ( isDiscriminatorField ( field ) )  { 
634588                    const  fieldDef  =  this . findNamedProperty ( typeAlias ,  field . name ) ; 
635589                    if  ( fieldDef )  { 
636-                         source  =  source . replace ( fieldDef . getText ( ) ,   '' ) ; 
590+                         source  =  this . removeFromSource ( source ,   fieldDef . getText ( ) ) ; 
637591                    } 
638592                } 
639593            } 
@@ -646,7 +600,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
646600        const  auxDecls  =  this . findAuxDecls ( typeAlias ) ; 
647601        if  ( auxDecls . length  >  0 )  { 
648602            auxDecls . forEach ( ( d )  =>  { 
649-                 source  =  source . replace ( d . getText ( ) ,   '' ) ; 
603+                 source  =  this . removeFromSource ( source ,   d . getText ( ) ) ; 
650604            } ) ; 
651605        } 
652606        return  source ; 
@@ -677,7 +631,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
677631        const  fieldDef  =  this . findNamedProperty ( typeAlias ,  relationFieldName ) ; 
678632        if  ( fieldDef )  { 
679633            // remove relation field of delegate type, e.g., `asset` 
680-             source  =  source . replace ( fieldDef . getText ( ) ,   '' ) ; 
634+             source  =  this . removeFromSource ( source ,   fieldDef . getText ( ) ) ; 
681635        } 
682636
683637        // remove fk fields related to the delegate type relation, e.g., `assetId` 
@@ -709,13 +663,103 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
709663        fkFields . forEach ( ( fkField )  =>  { 
710664            const  fieldDef  =  this . findNamedProperty ( typeAlias ,  fkField ) ; 
711665            if  ( fieldDef )  { 
712-                 source  =  source . replace ( fieldDef . getText ( ) ,   '' ) ; 
666+                 source  =  this . removeFromSource ( source ,   fieldDef . getText ( ) ) ; 
713667            } 
714668        } ) ; 
715669
716670        return  source ; 
717671    } 
718672
673+     private  fixDefaultAuthType ( typeAlias : TypeAliasDeclaration ,  source : string )  { 
674+         const  match  =  typeAlias . getName ( ) . match ( this . modelsWithAuthInDefaultCreateInputPattern ) ; 
675+         if  ( ! match )  { 
676+             return  source ; 
677+         } 
678+ 
679+         const  modelName  =  match [ 1 ] ; 
680+         const  dataModel  =  this . model . declarations . find ( ( d ) : d  is DataModel  =>  isDataModel ( d )  &&  d . name  ===  modelName ) ; 
681+         if  ( dataModel )  { 
682+             for  ( const  fkField  of  dataModel . fields . filter ( ( f )  =>  f . attributes . some ( isDefaultWithAuth ) ) )  { 
683+                 // change fk field to optional since it has a default 
684+                 source  =  source . replace ( new  RegExp ( `^(\\s*${ fkField . name }  ,  'm' ) ,  `$1?:` ) ; 
685+ 
686+                 const  relationField  =  getRelationField ( fkField ) ; 
687+                 if  ( relationField )  { 
688+                     // change relation field to optional since its fk has a default 
689+                     source  =  source . replace ( new  RegExp ( `^(\\s*${ relationField . name }  ,  'm' ) ,  `$1?:` ) ; 
690+                 } 
691+             } 
692+         } 
693+         return  source ; 
694+     } 
695+ 
696+     private  fixJsonFieldType ( typeAlias : TypeAliasDeclaration ,  source : string )  { 
697+         const  modelsWithTypeField  =  this . model . declarations . filter ( 
698+             ( d ) : d  is DataModel  =>  isDataModel ( d )  &&  d . fields . some ( ( f )  =>  isTypeDef ( f . type . reference ?. ref ) ) 
699+         ) ; 
700+         const  typeName  =  typeAlias . getName ( ) ; 
701+ 
702+         const  getTypedJsonFields  =  ( model : DataModel )  =>  { 
703+             return  model . fields . filter ( ( f )  =>  isTypeDef ( f . type . reference ?. ref ) ) ; 
704+         } ; 
705+ 
706+         const  replacePrismaJson  =  ( source : string ,  field : DataModelField )  =>  { 
707+             return  source . replace ( 
708+                 new  RegExp ( `(${ field . name }  ) , 
709+                 `$1: ${ field . type . reference ! . $refText } ${ field . type . array  ? '[]'  : '' } ${  
710+                     field . type . optional  ? ' | null'  : ''  
711+                 }  `
712+             ) ; 
713+         } ; 
714+ 
715+         // fix "$[Model]Payload" type 
716+         const  payloadModelMatch  =  modelsWithTypeField . find ( ( m )  =>  `$${ m . name }   ===  typeName ) ; 
717+         if  ( payloadModelMatch )  { 
718+             const  scalars  =  typeAlias 
719+                 . getDescendantsOfKind ( SyntaxKind . PropertySignature ) 
720+                 . find ( ( p )  =>  p . getName ( )  ===  'scalars' ) ; 
721+             if  ( ! scalars )  { 
722+                 return  source ; 
723+             } 
724+ 
725+             const  fieldsToFix  =  getTypedJsonFields ( payloadModelMatch ) ; 
726+             for  ( const  field  of  fieldsToFix )  { 
727+                 source  =  replacePrismaJson ( source ,  field ) ; 
728+             } 
729+         } 
730+ 
731+         // fix input/output types, "[Model]CreateInput", etc. 
732+         const  inputOutputModelMatch  =  modelsWithTypeField . find ( ( m )  =>  typeName . startsWith ( m . name ) ) ; 
733+         if  ( inputOutputModelMatch )  { 
734+             const  relevantTypePatterns  =  [ 
735+                 'GroupByOutputType' , 
736+                 '(Unchecked)?Create(\\S+?)?Input' , 
737+                 '(Unchecked)?Update(\\S+?)?Input' , 
738+                 'CreateManyInput' , 
739+                 '(Unchecked)?UpdateMany(Mutation)?Input' , 
740+             ] ; 
741+             const  typeRegex  =  modelsWithTypeField . map ( 
742+                 ( m )  =>  new  RegExp ( `^(${ m . name } ${ relevantTypePatterns . join ( '|' ) }  ) 
743+             ) ; 
744+             if  ( typeRegex . some ( ( r )  =>  r . test ( typeName ) ) )  { 
745+                 const  fieldsToFix  =  getTypedJsonFields ( inputOutputModelMatch ) ; 
746+                 for  ( const  field  of  fieldsToFix )  { 
747+                     source  =  replacePrismaJson ( source ,  field ) ; 
748+                 } 
749+             } 
750+         } 
751+ 
752+         return  source ; 
753+     } 
754+ 
755+     private  async  generateExtraTypes ( sf : SourceFile )  { 
756+         for  ( const  decl  of  this . model . declarations )  { 
757+             if  ( isTypeDef ( decl ) )  { 
758+                 generateTypeDefType ( sf ,  decl ) ; 
759+             } 
760+         } 
761+     } 
762+ 
719763    private  findNamedProperty ( typeAlias : TypeAliasDeclaration ,  name : string )  { 
720764        return  typeAlias . getFirstDescendant ( ( d )  =>  d . isKind ( SyntaxKind . PropertySignature )  &&  d . getName ( )  ===  name ) ; 
721765    } 
@@ -745,11 +789,12 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
745789        return  this . options . generatePermissionChecker  ===  true ; 
746790    } 
747791
748-     private  async  generateExtraTypes ( sf : SourceFile )  { 
749-         for  ( const  decl  of  this . model . declarations )  { 
750-             if  ( isTypeDef ( decl ) )  { 
751-                 generateTypeDefType ( sf ,  decl ) ; 
752-             } 
753-         } 
792+     private  removeFromSource ( source : string ,  text : string )  { 
793+         source  =  source . replace ( text ,  '' ) ; 
794+         return  this . trimEmptyLines ( source ) ; 
795+     } 
796+ 
797+     private  trimEmptyLines ( source : string ) : string  { 
798+         return  source . replace ( / ^ \s * [ \r \n ] / gm,  '' ) ; 
754799    } 
755800} 
0 commit comments