44 PluginOptions ,
55 createProject ,
66 emitProject ,
7+ getAttribute ,
8+ getAttributeArg ,
79 getDataModels ,
810 getLiteral ,
911 getPrismaClientImportSpec ,
@@ -15,7 +17,16 @@ import {
1517 resolvePath ,
1618 saveProject ,
1719} from '@zenstackhq/sdk' ;
18- import { DataModel , DataSource , EnumField , Model , isDataModel , isDataSource , isEnum } from '@zenstackhq/sdk/ast' ;
20+ import {
21+ DataModel ,
22+ DataModelField ,
23+ DataSource ,
24+ EnumField ,
25+ Model ,
26+ isDataModel ,
27+ isDataSource ,
28+ isEnum ,
29+ } from '@zenstackhq/sdk/ast' ;
1930import { addMissingInputObjectTypes , resolveAggregateOperationSupport } from '@zenstackhq/sdk/dmmf-helpers' ;
2031import { promises as fs } from 'fs' ;
2132import { streamAllContents } from 'langium' ;
@@ -262,10 +273,17 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
262273 sf . replaceWithText ( ( writer ) => {
263274 const fields = model . fields . filter (
264275 ( field ) =>
265- // scalar fields only
276+ // regular fields only
266277 ! isDataModel ( field . type . reference ?. ref ) && ! isForeignKeyField ( field )
267278 ) ;
268279
280+ const relations = model . fields . filter ( ( field ) => isDataModel ( field . type . reference ?. ref ) ) ;
281+ const fkFields = model . fields . filter ( ( field ) => isForeignKeyField ( field ) ) ;
282+ // unsafe version of relations: including foreign keys and relation fields without fk
283+ const unsafeRelations = model . fields . filter (
284+ ( field ) => isForeignKeyField ( field ) || ( isDataModel ( field . type . reference ?. ref ) && ! hasForeignKey ( field ) )
285+ ) ;
286+
269287 writer . writeLine ( '/* eslint-disable */' ) ;
270288 writer . writeLine ( `import { z } from 'zod';` ) ;
271289
@@ -302,7 +320,7 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
302320 writer . writeLine ( `import { Decimal } from 'decimal.js';` ) ;
303321 }
304322
305- // create base schema
323+ // base schema
306324 writer . write ( `const baseSchema = z.object(` ) ;
307325 writer . inlineBlock ( ( ) => {
308326 fields . forEach ( ( field ) => {
@@ -311,31 +329,92 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
311329 } ) ;
312330 writer . writeLine ( ');' ) ;
313331
332+ // relation fields
333+
334+ let allRelationSchema : string | undefined ;
335+ let safeRelationSchema : string | undefined ;
336+ let unsafeRelationSchema : string | undefined ;
337+
338+ if ( relations . length > 0 || fkFields . length > 0 ) {
339+ allRelationSchema = 'allRelationSchema' ;
340+ writer . write ( `const ${ allRelationSchema } = z.object(` ) ;
341+ writer . inlineBlock ( ( ) => {
342+ [ ...relations , ...fkFields ] . forEach ( ( field ) => {
343+ writer . writeLine ( `${ field . name } : ${ makeFieldSchema ( field ) } ,` ) ;
344+ } ) ;
345+ } ) ;
346+ writer . writeLine ( ');' ) ;
347+ }
348+
349+ if ( relations . length > 0 ) {
350+ safeRelationSchema = 'safeRelationSchema' ;
351+ writer . write ( `const ${ safeRelationSchema } = z.object(` ) ;
352+ writer . inlineBlock ( ( ) => {
353+ relations . forEach ( ( field ) => {
354+ writer . writeLine ( `${ field . name } : ${ makeFieldSchema ( field , true ) } ,` ) ;
355+ } ) ;
356+ } ) ;
357+ writer . writeLine ( ');' ) ;
358+ }
359+
360+ if ( unsafeRelations . length > 0 ) {
361+ unsafeRelationSchema = 'unsafeRelationSchema' ;
362+ writer . write ( `const ${ unsafeRelationSchema } = z.object(` ) ;
363+ writer . inlineBlock ( ( ) => {
364+ unsafeRelations . forEach ( ( field ) => {
365+ writer . writeLine ( `${ field . name } : ${ makeFieldSchema ( field , true ) } ,` ) ;
366+ } ) ;
367+ } ) ;
368+ writer . writeLine ( ');' ) ;
369+ }
370+
314371 // compile "@@validate" to ".refine"
315372 const refinements = makeValidationRefinements ( model ) ;
373+ let refineFuncName : string | undefined ;
316374 if ( refinements . length > 0 ) {
375+ refineFuncName = `refine${ upperCaseFirst ( model . name ) } ` ;
317376 writer . writeLine (
318- `function refine <T, D extends z.ZodTypeDef>(schema: z.ZodType<T, D, T>) { return schema${ refinements . join (
377+ `export function ${ refineFuncName } <T, D extends z.ZodTypeDef>(schema: z.ZodType<T, D, T>) { return schema${ refinements . join (
319378 '\n'
320379 ) } ; }`
321380 ) ;
322381 }
323382
324- // model schema
383+ ////////////////////////////////////////////////
384+ // 1. Model schema
385+ ////////////////////////////////////////////////
325386 let modelSchema = 'baseSchema' ;
387+
388+ // omit fields
326389 const fieldsToOmit = fields . filter ( ( field ) => hasAttribute ( field , '@omit' ) ) ;
327390 if ( fieldsToOmit . length > 0 ) {
328391 modelSchema = makeOmit (
329392 modelSchema ,
330393 fieldsToOmit . map ( ( f ) => f . name )
331394 ) ;
332395 }
333- if ( refinements . length > 0 ) {
334- modelSchema = `refine(${ modelSchema } )` ;
396+
397+ if ( allRelationSchema ) {
398+ // export schema with only scalar fields
399+ const modelScalarSchema = `${ upperCaseFirst ( model . name ) } ScalarSchema` ;
400+ writer . writeLine ( `export const ${ modelScalarSchema } = ${ modelSchema } ;` ) ;
401+ modelSchema = modelScalarSchema ;
402+
403+ // merge relations
404+ modelSchema = makeMerge ( modelSchema , allRelationSchema ) ;
405+ }
406+
407+ // refine
408+ if ( refineFuncName ) {
409+ const noRefineSchema = `${ upperCaseFirst ( model . name ) } WithoutRefineSchema` ;
410+ writer . writeLine ( `export const ${ noRefineSchema } = ${ modelSchema } ;` ) ;
411+ modelSchema = `${ refineFuncName } (${ noRefineSchema } )` ;
335412 }
336413 writer . writeLine ( `export const ${ upperCaseFirst ( model . name ) } Schema = ${ modelSchema } ;` ) ;
337414
338- // create schema
415+ ////////////////////////////////////////////////
416+ // 2. Create schema
417+ ////////////////////////////////////////////////
339418 let createSchema = 'baseSchema' ;
340419 const fieldsWithDefault = fields . filter (
341420 ( field ) => hasAttribute ( field , '@default' ) || hasAttribute ( field , '@updatedAt' ) || field . type . array
@@ -346,29 +425,104 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
346425 fieldsWithDefault . map ( ( f ) => f . name )
347426 ) ;
348427 }
349- if ( refinements . length > 0 ) {
350- createSchema = `refine(${ createSchema } )` ;
428+
429+ if ( safeRelationSchema || unsafeRelationSchema ) {
430+ // export schema with only scalar fields
431+ const createScalarSchema = `${ upperCaseFirst ( model . name ) } CreateScalarSchema` ;
432+ writer . writeLine ( `export const ${ createScalarSchema } = ${ createSchema } ;` ) ;
433+ createSchema = createScalarSchema ;
434+
435+ if ( safeRelationSchema && unsafeRelationSchema ) {
436+ // build a union of with relation object fields and with fk fields (mutually exclusive)
437+
438+ // TODO: we make all relation fields partial for now because in case of
439+ // nested create, not all relation/fk fields are inside payload, need a
440+ // better solution
441+ createSchema = makeUnion (
442+ makeMerge ( createSchema , makePartial ( safeRelationSchema ) ) ,
443+ makeMerge ( createSchema , makePartial ( unsafeRelationSchema ) )
444+ ) ;
445+ } else if ( safeRelationSchema ) {
446+ // just relation
447+
448+ // TODO: we make all relation fields partial for now because in case of
449+ // nested create, not all relation/fk fields are inside payload, need a
450+ // better solution
451+ createSchema = makeMerge ( createSchema , makePartial ( safeRelationSchema ) ) ;
452+ }
453+ }
454+
455+ if ( refineFuncName ) {
456+ // export a schema without refinement for extensibility
457+ const noRefineSchema = `${ upperCaseFirst ( model . name ) } CreateWithoutRefineSchema` ;
458+ writer . writeLine ( `export const ${ noRefineSchema } = ${ createSchema } ;` ) ;
459+ createSchema = `${ refineFuncName } (${ noRefineSchema } )` ;
351460 }
352461 writer . writeLine ( `export const ${ upperCaseFirst ( model . name ) } CreateSchema = ${ createSchema } ;` ) ;
353462
354- // update schema
355- let updateSchema = 'baseSchema.partial()' ;
356- if ( refinements . length > 0 ) {
357- updateSchema = `refine(${ updateSchema } )` ;
463+ ////////////////////////////////////////////////
464+ // 3. Update schema
465+ ////////////////////////////////////////////////
466+ let updateSchema = makePartial ( 'baseSchema' ) ;
467+
468+ if ( safeRelationSchema || unsafeRelationSchema ) {
469+ // export schema with only scalar fields
470+ const updateScalarSchema = `${ upperCaseFirst ( model . name ) } UpdateScalarSchema` ;
471+ writer . writeLine ( `export const ${ updateScalarSchema } = ${ updateSchema } ;` ) ;
472+ updateSchema = updateScalarSchema ;
473+
474+ if ( safeRelationSchema && unsafeRelationSchema ) {
475+ // build a union of with relation object fields and with fk fields (mutually exclusive)
476+ updateSchema = makeUnion (
477+ makeMerge ( updateSchema , makePartial ( safeRelationSchema ) ) ,
478+ makeMerge ( updateSchema , makePartial ( unsafeRelationSchema ) )
479+ ) ;
480+ } else if ( safeRelationSchema ) {
481+ // just relation
482+ updateSchema = makeMerge ( updateSchema , makePartial ( safeRelationSchema ) ) ;
483+ }
484+ }
485+
486+ if ( refineFuncName ) {
487+ // export a schema without refinement for extensibility
488+ const noRefineSchema = `${ upperCaseFirst ( model . name ) } UpdateWithoutRefineSchema` ;
489+ writer . writeLine ( `export const ${ noRefineSchema } = ${ updateSchema } ;` ) ;
490+ updateSchema = `${ refineFuncName } (${ noRefineSchema } )` ;
358491 }
359492 writer . writeLine ( `export const ${ upperCaseFirst ( model . name ) } UpdateSchema = ${ updateSchema } ;` ) ;
360493 } ) ;
494+
361495 return schemaName ;
362496}
363497
364- function makePartial ( schema : string , fields : string [ ] ) {
365- return `${ schema } .partial({
498+ function makePartial ( schema : string , fields ?: string [ ] ) {
499+ if ( fields ) {
500+ return `${ schema } .partial({
366501 ${ fields . map ( ( f ) => `${ f } : true` ) . join ( ', ' ) } ,
367502 })` ;
503+ } else {
504+ return `${ schema } .partial()` ;
505+ }
368506}
369507
370508function makeOmit ( schema : string , fields : string [ ] ) {
371509 return `${ schema } .omit({
372510 ${ fields . map ( ( f ) => `${ f } : true` ) . join ( ', ' ) } ,
373511 })` ;
374512}
513+
514+ function makeMerge ( schema1 : string , schema2 : string ) : string {
515+ return `${ schema1 } .merge(${ schema2 } )` ;
516+ }
517+
518+ function makeUnion ( ...schemas : string [ ] ) : string {
519+ return `z.union([${ schemas . join ( ', ' ) } ])` ;
520+ }
521+
522+ function hasForeignKey ( field : DataModelField ) {
523+ const relAttr = getAttribute ( field , '@relation' ) ;
524+ if ( ! relAttr ) {
525+ return false ;
526+ }
527+ return ! ! getAttributeArg ( relAttr , 'fields' ) ;
528+ }
0 commit comments