@@ -7,24 +7,28 @@ import {
77 emitProject ,
88 getDataModels ,
99 getLiteral ,
10+ getPrismaClientImportSpec ,
1011 hasAttribute ,
12+ isEnumFieldReference ,
1113 isForeignKeyField ,
1214 resolvePath ,
1315 saveProject ,
1416} from '@zenstackhq/sdk' ;
15- import { DataModel , DataSource , Model , isDataModel , isDataSource , isEnum } from '@zenstackhq/sdk/ast' ;
17+ import { DataModel , DataSource , EnumField , Model , isDataModel , isDataSource , isEnum } from '@zenstackhq/sdk/ast' ;
1618import {
1719 AggregateOperationSupport ,
1820 addMissingInputObjectTypes ,
1921 resolveAggregateOperationSupport ,
2022} from '@zenstackhq/sdk/dmmf-helpers' ;
2123import { promises as fs } from 'fs' ;
24+ import { streamAllContents } from 'langium' ;
2225import path from 'path' ;
2326import { Project } from 'ts-morph' ;
27+ import { upperCaseFirst } from 'upper-case-first' ;
28+ import { isFromStdlib } from '../../language-server/utils' ;
2429import { getDefaultOutputFolder } from '../plugin-utils' ;
2530import Transformer from './transformer' ;
2631import removeDir from './utils/removeDir' ;
27- import { upperCaseFirst } from 'upper-case-first' ;
2832import { makeFieldSchema , makeValidationRefinements } from './utils/schema-gen' ;
2933
3034export async function generate ( model : Model , options : PluginOptions , dmmf : DMMF . Document ) {
@@ -176,8 +180,6 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
176180 overwrite : true ,
177181 } ) ;
178182 sf . replaceWithText ( ( writer ) => {
179- writer . writeLine ( '/* eslint-disable */' ) ;
180-
181183 const fields = model . fields . filter (
182184 ( field ) =>
183185 ! AUXILIARY_FIELDS . includes ( field . name ) &&
@@ -186,9 +188,25 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
186188 ! isForeignKeyField ( field )
187189 ) ;
188190
191+ writer . writeLine ( '/* eslint-disable */' ) ;
189192 writer . writeLine ( `import { z } from 'zod';` ) ;
190193
191- // import enums
194+ // import user-defined enums from Prisma as they might be referenced in the expressions
195+ const importEnums = new Set < string > ( ) ;
196+ for ( const node of streamAllContents ( model ) ) {
197+ if ( isEnumFieldReference ( node ) ) {
198+ const field = node . target . ref as EnumField ;
199+ if ( ! isFromStdlib ( field . $container ) ) {
200+ importEnums . add ( field . $container . name ) ;
201+ }
202+ }
203+ }
204+ if ( importEnums . size > 0 ) {
205+ const prismaImport = getPrismaClientImportSpec ( model . $container , path . join ( output , 'models' ) ) ;
206+ writer . writeLine ( `import { ${ [ ...importEnums ] . join ( ', ' ) } } from '${ prismaImport } ';` ) ;
207+ }
208+
209+ // import enum schemas
192210 for ( const field of fields ) {
193211 if ( field . type . reference ?. ref && isEnum ( field . type . reference ?. ref ) ) {
194212 const name = upperCaseFirst ( field . type . reference ?. ref . name ) ;
@@ -205,9 +223,9 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
205223 } ) ;
206224 writer . writeLine ( ');' ) ;
207225
226+ // compile "@@validate" to ".refine"
208227 const refinements = makeValidationRefinements ( model ) ;
209228 if ( refinements . length > 0 ) {
210- console . log ( 'Generated refinements:' , refinements ) ;
211229 writer . writeLine ( `function refine(schema: z.ZodType) { return schema${ refinements . join ( '\n' ) } ; }` ) ;
212230 }
213231
0 commit comments