11import  { 
2+     ExpressionContext , 
23    PluginError , 
34    PluginGlobalOptions , 
45    PluginOptions , 
56    RUNTIME_PACKAGE , 
7+     TypeScriptExpressionTransformer , 
8+     TypeScriptExpressionTransformerError , 
69    ensureEmptyDir , 
10+     getAttributeArg , 
11+     getAttributeArgLiteral , 
712    getDataModels , 
13+     getLiteralArray , 
814    hasAttribute , 
15+     isDataModelFieldReference , 
916    isDiscriminatorField , 
1017    isEnumFieldReference , 
1118    isForeignKeyField , 
@@ -15,7 +22,7 @@ import {
1522    resolvePath , 
1623    saveSourceFile , 
1724}  from  '@zenstackhq/sdk' ; 
18- import  {  DataModel ,  EnumField ,  Model ,  TypeDef ,  isDataModel ,  isEnum ,  isTypeDef  }  from  '@zenstackhq/sdk/ast' ; 
25+ import  {  DataModel ,  EnumField ,  Model ,  TypeDef ,  isArrayExpr ,   isDataModel ,  isEnum ,  isTypeDef  }  from  '@zenstackhq/sdk/ast' ; 
1926import  {  addMissingInputObjectTypes ,  resolveAggregateOperationSupport  }  from  '@zenstackhq/sdk/dmmf-helpers' ; 
2027import  {  getPrismaClientImportSpec ,  supportCreateMany ,  type  DMMF  }  from  '@zenstackhq/sdk/prisma' ; 
2128import  {  streamAllContents  }  from  'langium' ; 
@@ -26,7 +33,7 @@ import { name } from '.';
2633import  {  getDefaultOutputFolder  }  from  '../plugin-utils' ; 
2734import  Transformer  from  './transformer' ; 
2835import  {  ObjectMode  }  from  './types' ; 
29- import  {  makeFieldSchema ,   makeValidationRefinements  }  from  './utils/schema-gen' ; 
36+ import  {  makeFieldSchema  }  from  './utils/schema-gen' ; 
3037
3138export  class  ZodSchemaGenerator  { 
3239    private  readonly  sourceFiles : SourceFile [ ]  =  [ ] ; 
@@ -294,7 +301,7 @@ export class ZodSchemaGenerator {
294301        sf . replaceWithText ( ( writer )  =>  { 
295302            this . addPreludeAndImports ( typeDef ,  writer ,  output ) ; 
296303
297-             writer . write ( `export  const ${ typeDef . name } Schema  = z.object(` ) ; 
304+             writer . write ( `const baseSchema  = z.object(` ) ; 
298305            writer . inlineBlock ( ( )  =>  { 
299306                typeDef . fields . forEach ( ( field )  =>  { 
300307                    writer . writeLine ( `${ field . name } ${ makeFieldSchema ( field ) }  ) ; 
@@ -313,9 +320,24 @@ export class ZodSchemaGenerator {
313320                    writer . writeLine ( ').strict();' ) ; 
314321                    break ; 
315322            } 
316-         } ) ; 
317323
318-         // TODO: "@@validate" refinements 
324+             // compile "@@validate" to a function calling zod's `.refine()` 
325+             const  refineFuncName  =  this . createRefineFunction ( typeDef ,  writer ) ; 
326+ 
327+             if  ( refineFuncName )  { 
328+                 // export a schema without refinement for extensibility: `[Model]WithoutRefineSchema` 
329+                 const  noRefineSchema  =  `${ upperCaseFirst ( typeDef . name ) }  ; 
330+                 writer . writeLine ( ` 
331+ /** 
332+  * \`${ typeDef . name }  
333+  */ 
334+ export const ${ noRefineSchema }  
335+ export const ${ typeDef . name } ${ refineFuncName } ${ noRefineSchema }  
336+ ` ) ; 
337+             }  else  { 
338+                 writer . writeLine ( `export const ${ typeDef . name }  ) ; 
339+             } 
340+         } ) ; 
319341
320342        return  schemaName ; 
321343    } 
@@ -436,22 +458,7 @@ export class ZodSchemaGenerator {
436458            } 
437459
438460            // compile "@@validate" to ".refine" 
439-             const  refinements  =  makeValidationRefinements ( model ) ; 
440-             let  refineFuncName : string  |  undefined ; 
441-             if  ( refinements . length  >  0 )  { 
442-                 refineFuncName  =  `refine${ upperCaseFirst ( model . name ) }  ; 
443-                 writer . writeLine ( 
444-                     ` 
445- /** 
446-  * Schema refinement function for applying \`@@validate\` rules. 
447-  */ 
448- export function ${ refineFuncName } ${ refinements . join (  
449-                         '\n'  
450-                     ) }  ;
451- } 
452- ` 
453-                 ) ; 
454-             } 
461+             const  refineFuncName  =  this . createRefineFunction ( model ,  writer ) ; 
455462
456463            // delegate discriminator fields are to be excluded from mutation schemas 
457464            const  delegateDiscriminatorFields  =  model . fields . filter ( ( field )  =>  isDiscriminatorField ( field ) ) ; 
@@ -658,6 +665,74 @@ export const ${upperCaseFirst(model.name)}UpdateSchema = ${updateSchema};
658665        return  schemaName ; 
659666    } 
660667
668+     private  createRefineFunction ( decl : DataModel  |  TypeDef ,  writer : CodeBlockWriter )  { 
669+         const  refinements  =  this . makeValidationRefinements ( decl ) ; 
670+         let  refineFuncName : string  |  undefined ; 
671+         if  ( refinements . length  >  0 )  { 
672+             refineFuncName  =  `refine${ upperCaseFirst ( decl . name ) }  ; 
673+             writer . writeLine ( 
674+                 ` 
675+     /** 
676+     * Schema refinement function for applying \`@@validate\` rules. 
677+     */ 
678+     export function ${ refineFuncName } ${ refinements . join (  
679+                     '\n'  
680+                 ) }  ;
681+     } 
682+     ` 
683+             ) ; 
684+             return  refineFuncName ; 
685+         }  else  { 
686+             return  undefined ; 
687+         } 
688+     } 
689+ 
690+     private  makeValidationRefinements ( decl : DataModel  |  TypeDef )  { 
691+         const  attrs  =  decl . attributes . filter ( ( attr )  =>  attr . decl . ref ?. name  ===  '@@validate' ) ; 
692+         const  refinements  =  attrs 
693+             . map ( ( attr )  =>  { 
694+                 const  valueArg  =  getAttributeArg ( attr ,  'value' ) ; 
695+                 if  ( ! valueArg )  { 
696+                     return  undefined ; 
697+                 } 
698+ 
699+                 const  messageArg  =  getAttributeArgLiteral < string > ( attr ,  'message' ) ; 
700+                 const  message  =  messageArg  ? `message: ${ JSON . stringify ( messageArg ) }   : '' ; 
701+ 
702+                 const  pathArg  =  getAttributeArg ( attr ,  'path' ) ; 
703+                 const  path  = 
704+                     pathArg  &&  isArrayExpr ( pathArg ) 
705+                         ? `path: ['${ getLiteralArray < string > ( pathArg ) ?. join ( `', '` ) }  
706+                         : '' ; 
707+ 
708+                 const  options  =  `, { ${ message } ${ path }  ; 
709+ 
710+                 try  { 
711+                     let  expr  =  new  TypeScriptExpressionTransformer ( { 
712+                         context : ExpressionContext . ValidationRule , 
713+                         fieldReferenceContext : 'value' , 
714+                     } ) . transform ( valueArg ) ; 
715+ 
716+                     if  ( isDataModelFieldReference ( valueArg ) )  { 
717+                         // if the expression is a simple field reference, treat undefined 
718+                         // as true since the all fields are optional in validation context 
719+                         expr  =  `${ expr }  ; 
720+                     } 
721+ 
722+                     return  `.refine((value: any) => ${ expr } ${ options }  ; 
723+                 }  catch  ( err )  { 
724+                     if  ( err  instanceof  TypeScriptExpressionTransformerError )  { 
725+                         throw  new  PluginError ( name ,  err . message ) ; 
726+                     }  else  { 
727+                         throw  err ; 
728+                     } 
729+                 } 
730+             } ) 
731+             . filter ( ( r )  =>  ! ! r ) ; 
732+ 
733+         return  refinements ; 
734+     } 
735+ 
661736    private  makePartial ( schema : string ,  fields ?: string [ ] )  { 
662737        if  ( fields )  { 
663738            if  ( fields . length  ===  0 )  { 
0 commit comments