77 getDataModelAndTypeDefs ,
88 getDataModels ,
99 getLiteral ,
10+ getRelationField ,
1011 isDelegateModel ,
1112 isDiscriminatorField ,
1213 normalizedRelative ,
@@ -55,12 +56,23 @@ type DelegateInfo = [DataModel, DataModel[]][];
5556const LOGICAL_CLIENT_GENERATION_PATH = './.logical-prisma-client' ;
5657
5758export class EnhancerGenerator {
59+ // regex for matching "ModelCreateXXXInput" and "ModelUncheckedCreateXXXInput" type
60+ // names for models that use `auth()` in `@default` attribute
61+ private readonly modelsWithAuthInDefaultCreateInputPattern : RegExp ;
62+
5863 constructor (
5964 private readonly model : Model ,
6065 private readonly options : PluginOptions ,
6166 private readonly project : Project ,
6267 private readonly outDir : string
63- ) { }
68+ ) {
69+ const modelsWithAuthInDefault = this . model . declarations . filter (
70+ ( d ) : d is DataModel => isDataModel ( d ) && d . fields . some ( ( f ) => f . attributes . some ( isDefaultWithAuth ) )
71+ ) ;
72+ this . modelsWithAuthInDefaultCreateInputPattern = new RegExp (
73+ `^(${ modelsWithAuthInDefault . map ( ( m ) => m . name ) . join ( '|' ) } )(Unchecked)?Create.*?Input$`
74+ ) ;
75+ }
6476
6577 async generate ( ) : Promise < { dmmf : DMMF . Document | undefined ; newPrismaClientDtsPath : string | undefined } > {
6678 let dmmf : DMMF . Document | undefined ;
@@ -69,7 +81,7 @@ export class EnhancerGenerator {
6981 let prismaTypesFixed = false ;
7082 let resultPrismaImport = prismaImport ;
7183
72- if ( this . needsLogicalClient || this . needsPrismaClientTypeFixes ) {
84+ if ( this . needsLogicalClient ) {
7385 prismaTypesFixed = true ;
7486 resultPrismaImport = `${ LOGICAL_CLIENT_GENERATION_PATH } /index-fixed` ;
7587 const result = await this . generateLogicalPrisma ( ) ;
@@ -230,11 +242,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
230242 }
231243
232244 private get needsLogicalClient ( ) {
233- return this . hasDelegateModel ( this . model ) || this . hasAuthInDefault ( this . model ) ;
234- }
235-
236- private get needsPrismaClientTypeFixes ( ) {
237- return this . hasTypeDef ( this . model ) ;
245+ return this . hasDelegateModel ( this . model ) || this . hasAuthInDefault ( this . model ) || this . hasTypeDef ( this . model ) ;
238246 }
239247
240248 private hasDelegateModel ( model : Model ) {
@@ -449,11 +457,13 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
449457 const auxFields = this . findAuxDecls ( variable ) ;
450458 if ( auxFields . length > 0 ) {
451459 structure . declarations . forEach ( ( variable ) => {
452- let source = variable . type ?. toString ( ) ;
453- auxFields . forEach ( ( f ) => {
454- source = source ?. replace ( f . getText ( ) , '' ) ;
455- } ) ;
456- variable . type = source ;
460+ if ( variable . type ) {
461+ let source = variable . type . toString ( ) ;
462+ auxFields . forEach ( ( f ) => {
463+ source = this . removeFromSource ( source , f . getText ( ) ) ;
464+ } ) ;
465+ variable . type = source ;
466+ }
457467 } ) ;
458468 }
459469
@@ -498,72 +508,16 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
498508 // fix delegate payload union type
499509 source = this . fixDelegatePayloadType ( typeAlias , delegateInfo , source ) ;
500510
511+ // fix fk and relation fields related to using `auth()` in `@default`
512+ source = this . fixDefaultAuthType ( typeAlias , source ) ;
513+
501514 // fix json field type
502515 source = this . fixJsonFieldType ( typeAlias , source ) ;
503516
504517 structure . type = source ;
505518 return structure ;
506519 }
507520
508- private fixJsonFieldType ( typeAlias : TypeAliasDeclaration , source : string ) {
509- const modelsWithTypeField = this . model . declarations . filter (
510- ( d ) : d is DataModel => isDataModel ( d ) && d . fields . some ( ( f ) => isTypeDef ( f . type . reference ?. ref ) )
511- ) ;
512- const typeName = typeAlias . getName ( ) ;
513-
514- const getTypedJsonFields = ( model : DataModel ) => {
515- return model . fields . filter ( ( f ) => isTypeDef ( f . type . reference ?. ref ) ) ;
516- } ;
517-
518- const replacePrismaJson = ( source : string , field : DataModelField ) => {
519- return source . replace (
520- new RegExp ( `(${ field . name } \\??\\s*):[^\\n]+` ) ,
521- `$1: ${ field . type . reference ! . $refText } ${ field . type . array ? '[]' : '' } ${
522- field . type . optional ? ' | null' : ''
523- } `
524- ) ;
525- } ;
526-
527- // fix "$[Model]Payload" type
528- const payloadModelMatch = modelsWithTypeField . find ( ( m ) => `$${ m . name } Payload` === typeName ) ;
529- if ( payloadModelMatch ) {
530- const scalars = typeAlias
531- . getDescendantsOfKind ( SyntaxKind . PropertySignature )
532- . find ( ( p ) => p . getName ( ) === 'scalars' ) ;
533- if ( ! scalars ) {
534- return source ;
535- }
536-
537- const fieldsToFix = getTypedJsonFields ( payloadModelMatch ) ;
538- for ( const field of fieldsToFix ) {
539- source = replacePrismaJson ( source , field ) ;
540- }
541- }
542-
543- // fix input/output types, "[Model]CreateInput", etc.
544- const inputOutputModelMatch = modelsWithTypeField . find ( ( m ) => typeName . startsWith ( m . name ) ) ;
545- if ( inputOutputModelMatch ) {
546- const relevantTypePatterns = [
547- 'GroupByOutputType' ,
548- '(Unchecked)?Create(\\S+?)?Input' ,
549- '(Unchecked)?Update(\\S+?)?Input' ,
550- 'CreateManyInput' ,
551- '(Unchecked)?UpdateMany(Mutation)?Input' ,
552- ] ;
553- const typeRegex = modelsWithTypeField . map (
554- ( m ) => new RegExp ( `^(${ m . name } )(${ relevantTypePatterns . join ( '|' ) } )$` )
555- ) ;
556- if ( typeRegex . some ( ( r ) => r . test ( typeName ) ) ) {
557- const fieldsToFix = getTypedJsonFields ( inputOutputModelMatch ) ;
558- for ( const field of fieldsToFix ) {
559- source = replacePrismaJson ( source , field ) ;
560- }
561- }
562- }
563-
564- return source ;
565- }
566-
567521 private fixDelegatePayloadType ( typeAlias : TypeAliasDeclaration , delegateInfo : DelegateInfo , source : string ) {
568522 // change the type of `$<DelegateModel>Payload` type of delegate model to a union of concrete types
569523 const typeName = typeAlias . getName ( ) ;
@@ -595,7 +549,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
595549 . getDescendantsOfKind ( SyntaxKind . PropertySignature )
596550 . filter ( ( p ) => [ 'create' , 'createMany' , 'connectOrCreate' , 'upsert' ] . includes ( p . getName ( ) ) ) ;
597551 toRemove . forEach ( ( r ) => {
598- source = source . replace ( r . getText ( ) , '' ) ;
552+ this . removeFromSource ( source , r . getText ( ) ) ;
599553 } ) ;
600554 }
601555 return source ;
@@ -633,7 +587,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
633587 if ( isDiscriminatorField ( field ) ) {
634588 const fieldDef = this . findNamedProperty ( typeAlias , field . name ) ;
635589 if ( fieldDef ) {
636- source = source . replace ( fieldDef . getText ( ) , '' ) ;
590+ source = this . removeFromSource ( source , fieldDef . getText ( ) ) ;
637591 }
638592 }
639593 }
@@ -646,7 +600,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
646600 const auxDecls = this . findAuxDecls ( typeAlias ) ;
647601 if ( auxDecls . length > 0 ) {
648602 auxDecls . forEach ( ( d ) => {
649- source = source . replace ( d . getText ( ) , '' ) ;
603+ source = this . removeFromSource ( source , d . getText ( ) ) ;
650604 } ) ;
651605 }
652606 return source ;
@@ -677,7 +631,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
677631 const fieldDef = this . findNamedProperty ( typeAlias , relationFieldName ) ;
678632 if ( fieldDef ) {
679633 // remove relation field of delegate type, e.g., `asset`
680- source = source . replace ( fieldDef . getText ( ) , '' ) ;
634+ source = this . removeFromSource ( source , fieldDef . getText ( ) ) ;
681635 }
682636
683637 // remove fk fields related to the delegate type relation, e.g., `assetId`
@@ -709,13 +663,103 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
709663 fkFields . forEach ( ( fkField ) => {
710664 const fieldDef = this . findNamedProperty ( typeAlias , fkField ) ;
711665 if ( fieldDef ) {
712- source = source . replace ( fieldDef . getText ( ) , '' ) ;
666+ source = this . removeFromSource ( source , fieldDef . getText ( ) ) ;
713667 }
714668 } ) ;
715669
716670 return source ;
717671 }
718672
673+ private fixDefaultAuthType ( typeAlias : TypeAliasDeclaration , source : string ) {
674+ const match = typeAlias . getName ( ) . match ( this . modelsWithAuthInDefaultCreateInputPattern ) ;
675+ if ( ! match ) {
676+ return source ;
677+ }
678+
679+ const modelName = match [ 1 ] ;
680+ const dataModel = this . model . declarations . find ( ( d ) : d is DataModel => isDataModel ( d ) && d . name === modelName ) ;
681+ if ( dataModel ) {
682+ for ( const fkField of dataModel . fields . filter ( ( f ) => f . attributes . some ( isDefaultWithAuth ) ) ) {
683+ // change fk field to optional since it has a default
684+ source = source . replace ( new RegExp ( `^(\\s*${ fkField . name } \\s*):` , 'm' ) , `$1?:` ) ;
685+
686+ const relationField = getRelationField ( fkField ) ;
687+ if ( relationField ) {
688+ // change relation field to optional since its fk has a default
689+ source = source . replace ( new RegExp ( `^(\\s*${ relationField . name } \\s*):` , 'm' ) , `$1?:` ) ;
690+ }
691+ }
692+ }
693+ return source ;
694+ }
695+
696+ private fixJsonFieldType ( typeAlias : TypeAliasDeclaration , source : string ) {
697+ const modelsWithTypeField = this . model . declarations . filter (
698+ ( d ) : d is DataModel => isDataModel ( d ) && d . fields . some ( ( f ) => isTypeDef ( f . type . reference ?. ref ) )
699+ ) ;
700+ const typeName = typeAlias . getName ( ) ;
701+
702+ const getTypedJsonFields = ( model : DataModel ) => {
703+ return model . fields . filter ( ( f ) => isTypeDef ( f . type . reference ?. ref ) ) ;
704+ } ;
705+
706+ const replacePrismaJson = ( source : string , field : DataModelField ) => {
707+ return source . replace (
708+ new RegExp ( `(${ field . name } \\??\\s*):[^\\n]+` ) ,
709+ `$1: ${ field . type . reference ! . $refText } ${ field . type . array ? '[]' : '' } ${
710+ field . type . optional ? ' | null' : ''
711+ } `
712+ ) ;
713+ } ;
714+
715+ // fix "$[Model]Payload" type
716+ const payloadModelMatch = modelsWithTypeField . find ( ( m ) => `$${ m . name } Payload` === typeName ) ;
717+ if ( payloadModelMatch ) {
718+ const scalars = typeAlias
719+ . getDescendantsOfKind ( SyntaxKind . PropertySignature )
720+ . find ( ( p ) => p . getName ( ) === 'scalars' ) ;
721+ if ( ! scalars ) {
722+ return source ;
723+ }
724+
725+ const fieldsToFix = getTypedJsonFields ( payloadModelMatch ) ;
726+ for ( const field of fieldsToFix ) {
727+ source = replacePrismaJson ( source , field ) ;
728+ }
729+ }
730+
731+ // fix input/output types, "[Model]CreateInput", etc.
732+ const inputOutputModelMatch = modelsWithTypeField . find ( ( m ) => typeName . startsWith ( m . name ) ) ;
733+ if ( inputOutputModelMatch ) {
734+ const relevantTypePatterns = [
735+ 'GroupByOutputType' ,
736+ '(Unchecked)?Create(\\S+?)?Input' ,
737+ '(Unchecked)?Update(\\S+?)?Input' ,
738+ 'CreateManyInput' ,
739+ '(Unchecked)?UpdateMany(Mutation)?Input' ,
740+ ] ;
741+ const typeRegex = modelsWithTypeField . map (
742+ ( m ) => new RegExp ( `^(${ m . name } )(${ relevantTypePatterns . join ( '|' ) } )$` )
743+ ) ;
744+ if ( typeRegex . some ( ( r ) => r . test ( typeName ) ) ) {
745+ const fieldsToFix = getTypedJsonFields ( inputOutputModelMatch ) ;
746+ for ( const field of fieldsToFix ) {
747+ source = replacePrismaJson ( source , field ) ;
748+ }
749+ }
750+ }
751+
752+ return source ;
753+ }
754+
755+ private async generateExtraTypes ( sf : SourceFile ) {
756+ for ( const decl of this . model . declarations ) {
757+ if ( isTypeDef ( decl ) ) {
758+ generateTypeDefType ( sf , decl ) ;
759+ }
760+ }
761+ }
762+
719763 private findNamedProperty ( typeAlias : TypeAliasDeclaration , name : string ) {
720764 return typeAlias . getFirstDescendant ( ( d ) => d . isKind ( SyntaxKind . PropertySignature ) && d . getName ( ) === name ) ;
721765 }
@@ -745,11 +789,12 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
745789 return this . options . generatePermissionChecker === true ;
746790 }
747791
748- private async generateExtraTypes ( sf : SourceFile ) {
749- for ( const decl of this . model . declarations ) {
750- if ( isTypeDef ( decl ) ) {
751- generateTypeDefType ( sf , decl ) ;
752- }
753- }
792+ private removeFromSource ( source : string , text : string ) {
793+ source = source . replace ( text , '' ) ;
794+ return this . trimEmptyLines ( source ) ;
795+ }
796+
797+ private trimEmptyLines ( source : string ) : string {
798+ return source . replace ( / ^ \s * [ \r \n ] / gm, '' ) ;
754799 }
755800}
0 commit comments