1- import { getIdFields , isAuthInvocation , isDataModelFieldReference } from '@zenstackhq/sdk' ;
1+ import { getIdFields , getPrismaClientGenerator , isAuthInvocation , isDataModelFieldReference } from '@zenstackhq/sdk' ;
22import {
33 DataModel ,
44 DataModelField ,
55 Expression ,
66 isDataModel ,
77 isMemberAccessExpr ,
8+ isTypeDef ,
89 TypeDef ,
910 type Model ,
1011} from '@zenstackhq/sdk/ast' ;
@@ -19,27 +20,36 @@ export function generateAuthType(model: Model, authDecl: DataModel | TypeDef) {
1920 const types = new Map <
2021 string ,
2122 {
23+ isTypeDef : boolean ;
2224 // relation fields to require
2325 requiredRelations : { name : string ; type : string } [ ] ;
2426 }
2527 > ( ) ;
2628
27- types . set ( authDecl . name , { requiredRelations : [ ] } ) ;
29+ types . set ( authDecl . name , { isTypeDef : isTypeDef ( authDecl ) , requiredRelations : [ ] } ) ;
2830
29- const ensureType = ( model : string ) => {
30- if ( ! types . has ( model ) ) {
31- types . set ( model , { requiredRelations : [ ] } ) ;
31+ const findType = ( name : string ) =>
32+ model . declarations . find ( ( d ) => ( isDataModel ( d ) || isTypeDef ( d ) ) && d . name === name ) ;
33+
34+ const ensureType = ( name : string ) => {
35+ if ( ! types . has ( name ) ) {
36+ const decl = findType ( name ) ;
37+ if ( ! decl ) {
38+ return ;
39+ }
40+ types . set ( name , { isTypeDef : isTypeDef ( decl ) , requiredRelations : [ ] } ) ;
3241 }
3342 } ;
3443
35- const addAddField = ( model : string , name : string , type : string , array : boolean ) => {
36- let fields = types . get ( model ) ;
37- if ( ! fields ) {
38- fields = { requiredRelations : [ ] } ;
39- types . set ( model , fields ) ;
44+ const addTypeField = ( typeName : string , fieldName : string , fieldType : string , array : boolean ) => {
45+ let typeInfo = types . get ( typeName ) ;
46+ if ( ! typeInfo ) {
47+ const decl = findType ( typeName ) ;
48+ typeInfo = { isTypeDef : isTypeDef ( decl ) , requiredRelations : [ ] } ;
49+ types . set ( typeName , typeInfo ) ;
4050 }
41- if ( ! fields . requiredRelations . find ( ( f ) => f . name === name ) ) {
42- fields . requiredRelations . push ( { name, type : array ? `${ type } []` : type } ) ;
51+ if ( ! typeInfo . requiredRelations . find ( ( f ) => f . name === fieldName ) ) {
52+ typeInfo . requiredRelations . push ( { name : fieldName , type : array ? `${ fieldType } []` : fieldType } ) ;
4353 }
4454 } ;
4555
@@ -57,7 +67,7 @@ export function generateAuthType(model: Model, authDecl: DataModel | TypeDef) {
5767 // member is a relation
5868 const fieldType = memberDecl . type . reference . ref . name ;
5969 ensureType ( fieldType ) ;
60- addAddField ( exprType . name , memberDecl . name , fieldType , memberDecl . type . array ) ;
70+ addTypeField ( exprType . name , memberDecl . name , fieldType , memberDecl . type . array ) ;
6171 }
6272 }
6373 }
@@ -69,12 +79,15 @@ export function generateAuthType(model: Model, authDecl: DataModel | TypeDef) {
6979 if ( isDataModel ( fieldType ) ) {
7080 // field is a relation
7181 ensureType ( fieldType . name ) ;
72- addAddField ( fieldDecl . $container . name , node . target . $refText , fieldType . name , fieldDecl . type . array ) ;
82+ addTypeField ( fieldDecl . $container . name , node . target . $refText , fieldType . name , fieldDecl . type . array ) ;
7383 }
7484 }
7585 } ) ;
7686 } ) ;
7787
88+ const prismaGenerator = getPrismaClientGenerator ( model ) ;
89+ const isNewGenerator = ! ! prismaGenerator ?. isNewGenerator ;
90+
7891 // generate:
7992 // `
8093 // namespace auth {
@@ -86,25 +99,27 @@ export function generateAuthType(model: Model, authDecl: DataModel | TypeDef) {
8699 return `export namespace auth {
87100 type WithRequired<T, K extends keyof T> = T & { [P in K]-?: T[P] };
88101${ Array . from ( types . entries ( ) )
89- . map ( ( [ model , fields ] ) => {
90- let result = `Partial<_P.${ model } >` ;
102+ . map ( ( [ type , typeInfo ] ) => {
103+ // TypeDef types are generated in "json-types.ts" for the new "prisma-client" generator
104+ const typeRef = isNewGenerator ? `$TypeDefs.${ type } ` : `_P.${ type } ` ;
105+ let result = `Partial<${ typeRef } >` ;
91106
92- if ( model === authDecl . name ) {
107+ if ( type === authDecl . name ) {
93108 // auth model's id fields are always required
94109 const idFields = getIdFields ( authDecl ) . map ( ( f ) => f . name ) ;
95110 if ( idFields . length > 0 ) {
96111 result = `WithRequired<${ result } , ${ idFields . map ( ( f ) => `'${ f } '` ) . join ( '|' ) } >` ;
97112 }
98113 }
99114
100- if ( fields . requiredRelations . length > 0 ) {
115+ if ( typeInfo . requiredRelations . length > 0 ) {
101116 // merge required relation fields
102- result = `${ result } & { ${ fields . requiredRelations . map ( ( f ) => `${ f . name } : ${ f . type } ` ) . join ( '; ' ) } }` ;
117+ result = `${ result } & { ${ typeInfo . requiredRelations . map ( ( f ) => `${ f . name } : ${ f . type } ` ) . join ( '; ' ) } }` ;
103118 }
104119
105120 result = `${ result } & Record<string, unknown>` ;
106121
107- return ` export type ${ model } = ${ result } ;` ;
122+ return ` export type ${ type } = ${ result } ;` ;
108123 } )
109124 . join ( '\n' ) }
110125}` ;
0 commit comments