44 PluginOptions ,
55 createProject ,
66 emitProject ,
7- getAttribute ,
8- getAttributeArg ,
97 getDataModels ,
108 getLiteral ,
119 getPrismaClientImportSpec ,
@@ -17,16 +15,7 @@ import {
1715 resolvePath ,
1816 saveProject ,
1917} from '@zenstackhq/sdk' ;
20- import {
21- DataModel ,
22- DataModelField ,
23- DataSource ,
24- EnumField ,
25- Model ,
26- isDataModel ,
27- isDataSource ,
28- isEnum ,
29- } from '@zenstackhq/sdk/ast' ;
18+ import { DataModel , DataSource , EnumField , Model , isDataModel , isDataSource , isEnum } from '@zenstackhq/sdk/ast' ;
3019import { addMissingInputObjectTypes , resolveAggregateOperationSupport } from '@zenstackhq/sdk/dmmf-helpers' ;
3120import { promises as fs } from 'fs' ;
3221import { streamAllContents } from 'langium' ;
@@ -271,18 +260,14 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
271260 overwrite : true ,
272261 } ) ;
273262 sf . replaceWithText ( ( writer ) => {
274- const fields = model . fields . filter (
263+ const scalarFields = model . fields . filter (
275264 ( field ) =>
276265 // regular fields only
277266 ! isDataModel ( field . type . reference ?. ref ) && ! isForeignKeyField ( field )
278267 ) ;
279268
280269 const relations = model . fields . filter ( ( field ) => isDataModel ( field . type . reference ?. ref ) ) ;
281270 const fkFields = model . fields . filter ( ( field ) => isForeignKeyField ( field ) ) ;
282- // unsafe version of relations: including foreign keys and relation fields without fk
283- const unsafeRelations = model . fields . filter (
284- ( field ) => isForeignKeyField ( field ) || ( isDataModel ( field . type . reference ?. ref ) && ! hasForeignKey ( field ) )
285- ) ;
286271
287272 writer . writeLine ( '/* eslint-disable */' ) ;
288273 writer . writeLine ( `import { z } from 'zod';` ) ;
@@ -304,7 +289,7 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
304289
305290 // import enum schemas
306291 const importedEnumSchemas = new Set < string > ( ) ;
307- for ( const field of fields ) {
292+ for ( const field of scalarFields ) {
308293 if ( field . type . reference ?. ref && isEnum ( field . type . reference ?. ref ) ) {
309294 const name = upperCaseFirst ( field . type . reference ?. ref . name ) ;
310295 if ( ! importedEnumSchemas . has ( name ) ) {
@@ -315,29 +300,28 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
315300 }
316301
317302 // import Decimal
318- if ( fields . some ( ( field ) => field . type . type === 'Decimal' ) ) {
303+ if ( scalarFields . some ( ( field ) => field . type . type === 'Decimal' ) ) {
319304 writer . writeLine ( `import { DecimalSchema } from '../common';` ) ;
320305 writer . writeLine ( `import { Decimal } from 'decimal.js';` ) ;
321306 }
322307
323308 // base schema
324309 writer . write ( `const baseSchema = z.object(` ) ;
325310 writer . inlineBlock ( ( ) => {
326- fields . forEach ( ( field ) => {
311+ scalarFields . forEach ( ( field ) => {
327312 writer . writeLine ( `${ field . name } : ${ makeFieldSchema ( field ) } ,` ) ;
328313 } ) ;
329314 } ) ;
330315 writer . writeLine ( ');' ) ;
331316
332317 // relation fields
333318
334- let allRelationSchema : string | undefined ;
335- let safeRelationSchema : string | undefined ;
336- let unsafeRelationSchema : string | undefined ;
319+ let relationSchema : string | undefined ;
320+ let fkSchema : string | undefined ;
337321
338322 if ( relations . length > 0 || fkFields . length > 0 ) {
339- allRelationSchema = 'allRelationSchema ' ;
340- writer . write ( `const ${ allRelationSchema } = z.object(` ) ;
323+ relationSchema = 'relationSchema ' ;
324+ writer . write ( `const ${ relationSchema } = z.object(` ) ;
341325 writer . inlineBlock ( ( ) => {
342326 [ ...relations , ...fkFields ] . forEach ( ( field ) => {
343327 writer . writeLine ( `${ field . name } : ${ makeFieldSchema ( field ) } ,` ) ;
@@ -346,23 +330,12 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
346330 writer . writeLine ( ');' ) ;
347331 }
348332
349- if ( relations . length > 0 ) {
350- safeRelationSchema = 'safeRelationSchema ' ;
351- writer . write ( `const ${ safeRelationSchema } = z.object(` ) ;
333+ if ( fkFields . length > 0 ) {
334+ fkSchema = 'fkSchema ' ;
335+ writer . write ( `const ${ fkSchema } = z.object(` ) ;
352336 writer . inlineBlock ( ( ) => {
353- relations . forEach ( ( field ) => {
354- writer . writeLine ( `${ field . name } : ${ makeFieldSchema ( field , true ) } ,` ) ;
355- } ) ;
356- } ) ;
357- writer . writeLine ( ');' ) ;
358- }
359-
360- if ( unsafeRelations . length > 0 ) {
361- unsafeRelationSchema = 'unsafeRelationSchema' ;
362- writer . write ( `const ${ unsafeRelationSchema } = z.object(` ) ;
363- writer . inlineBlock ( ( ) => {
364- unsafeRelations . forEach ( ( field ) => {
365- writer . writeLine ( `${ field . name } : ${ makeFieldSchema ( field , true ) } ,` ) ;
337+ fkFields . forEach ( ( field ) => {
338+ writer . writeLine ( `${ field . name } : ${ makeFieldSchema ( field ) } ,` ) ;
366339 } ) ;
367340 } ) ;
368341 writer . writeLine ( ');' ) ;
@@ -383,25 +356,25 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
383356 ////////////////////////////////////////////////
384357 // 1. Model schema
385358 ////////////////////////////////////////////////
386- let modelSchema = 'baseSchema' ;
359+ let modelSchema = makePartial ( 'baseSchema' ) ;
387360
388361 // omit fields
389- const fieldsToOmit = fields . filter ( ( field ) => hasAttribute ( field , '@omit' ) ) ;
362+ const fieldsToOmit = scalarFields . filter ( ( field ) => hasAttribute ( field , '@omit' ) ) ;
390363 if ( fieldsToOmit . length > 0 ) {
391364 modelSchema = makeOmit (
392365 modelSchema ,
393366 fieldsToOmit . map ( ( f ) => f . name )
394367 ) ;
395368 }
396369
397- if ( allRelationSchema ) {
370+ if ( relationSchema ) {
398371 // export schema with only scalar fields
399372 const modelScalarSchema = `${ upperCaseFirst ( model . name ) } ScalarSchema` ;
400373 writer . writeLine ( `export const ${ modelScalarSchema } = ${ modelSchema } ;` ) ;
401374 modelSchema = modelScalarSchema ;
402375
403376 // merge relations
404- modelSchema = makeMerge ( modelSchema , allRelationSchema ) ;
377+ modelSchema = makeMerge ( modelSchema , makePartial ( relationSchema ) ) ;
405378 }
406379
407380 // refine
@@ -413,10 +386,40 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
413386 writer . writeLine ( `export const ${ upperCaseFirst ( model . name ) } Schema = ${ modelSchema } ;` ) ;
414387
415388 ////////////////////////////////////////////////
416- // 2. Create schema
389+ // 2. Prisma create & update
390+ ////////////////////////////////////////////////
391+
392+ // schema for validating prisma create input (all fields optional)
393+ let prismaCreateSchema = makePartial ( 'baseSchema' ) ;
394+ if ( refineFuncName ) {
395+ prismaCreateSchema = `${ refineFuncName } (${ prismaCreateSchema } )` ;
396+ }
397+ writer . writeLine ( `export const ${ upperCaseFirst ( model . name ) } PrismaCreateSchema = ${ prismaCreateSchema } ;` ) ;
398+
399+ // schema for validating prisma update input (all fields optional)
400+ // note numeric fields can be simple update or atomic operations
401+ let prismaUpdateSchema = `z.object({
402+ ${ scalarFields
403+ . map ( ( field ) => {
404+ let fieldSchema = makeFieldSchema ( field ) ;
405+ if ( field . type . type === 'Int' || field . type . type === 'Float' ) {
406+ fieldSchema = `z.union([${ fieldSchema } , z.record(z.unknown())])` ;
407+ }
408+ return `\t${ field . name } : ${ fieldSchema } ` ;
409+ } )
410+ . join ( ',\n' ) }
411+ })` ;
412+ prismaUpdateSchema = makePartial ( prismaUpdateSchema ) ;
413+ if ( refineFuncName ) {
414+ prismaUpdateSchema = `${ refineFuncName } (${ prismaUpdateSchema } )` ;
415+ }
416+ writer . writeLine ( `export const ${ upperCaseFirst ( model . name ) } PrismaUpdateSchema = ${ prismaUpdateSchema } ;` ) ;
417+
418+ ////////////////////////////////////////////////
419+ // 3. Create schema
417420 ////////////////////////////////////////////////
418421 let createSchema = 'baseSchema' ;
419- const fieldsWithDefault = fields . filter (
422+ const fieldsWithDefault = scalarFields . filter (
420423 ( field ) => hasAttribute ( field , '@default' ) || hasAttribute ( field , '@updatedAt' ) || field . type . array
421424 ) ;
422425 if ( fieldsWithDefault . length > 0 ) {
@@ -426,30 +429,13 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
426429 ) ;
427430 }
428431
429- if ( safeRelationSchema || unsafeRelationSchema ) {
432+ if ( fkSchema ) {
430433 // export schema with only scalar fields
431434 const createScalarSchema = `${ upperCaseFirst ( model . name ) } CreateScalarSchema` ;
432435 writer . writeLine ( `export const ${ createScalarSchema } = ${ createSchema } ;` ) ;
433- createSchema = createScalarSchema ;
434-
435- if ( safeRelationSchema && unsafeRelationSchema ) {
436- // build a union of with relation object fields and with fk fields (mutually exclusive)
437-
438- // TODO: we make all relation fields partial for now because in case of
439- // nested create, not all relation/fk fields are inside payload, need a
440- // better solution
441- createSchema = makeUnion (
442- makeMerge ( createSchema , makePartial ( safeRelationSchema ) ) ,
443- makeMerge ( createSchema , makePartial ( unsafeRelationSchema ) )
444- ) ;
445- } else if ( safeRelationSchema ) {
446- // just relation
447-
448- // TODO: we make all relation fields partial for now because in case of
449- // nested create, not all relation/fk fields are inside payload, need a
450- // better solution
451- createSchema = makeMerge ( createSchema , makePartial ( safeRelationSchema ) ) ;
452- }
436+
437+ // merge fk fields
438+ createSchema = makeMerge ( createScalarSchema , fkSchema ) ;
453439 }
454440
455441 if ( refineFuncName ) {
@@ -465,22 +451,14 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
465451 ////////////////////////////////////////////////
466452 let updateSchema = makePartial ( 'baseSchema' ) ;
467453
468- if ( safeRelationSchema || unsafeRelationSchema ) {
454+ if ( fkSchema ) {
469455 // export schema with only scalar fields
470456 const updateScalarSchema = `${ upperCaseFirst ( model . name ) } UpdateScalarSchema` ;
471457 writer . writeLine ( `export const ${ updateScalarSchema } = ${ updateSchema } ;` ) ;
472458 updateSchema = updateScalarSchema ;
473459
474- if ( safeRelationSchema && unsafeRelationSchema ) {
475- // build a union of with relation object fields and with fk fields (mutually exclusive)
476- updateSchema = makeUnion (
477- makeMerge ( updateSchema , makePartial ( safeRelationSchema ) ) ,
478- makeMerge ( updateSchema , makePartial ( unsafeRelationSchema ) )
479- ) ;
480- } else if ( safeRelationSchema ) {
481- // just relation
482- updateSchema = makeMerge ( updateSchema , makePartial ( safeRelationSchema ) ) ;
483- }
460+ // merge fk fields
461+ updateSchema = makeMerge ( updateSchema , makePartial ( fkSchema ) ) ;
484462 }
485463
486464 if ( refineFuncName ) {
@@ -514,15 +492,3 @@ function makeOmit(schema: string, fields: string[]) {
514492function makeMerge ( schema1 : string , schema2 : string ) : string {
515493 return `${ schema1 } .merge(${ schema2 } )` ;
516494}
517-
518- function makeUnion ( ...schemas : string [ ] ) : string {
519- return `z.union([${ schemas . join ( ', ' ) } ])` ;
520- }
521-
522- function hasForeignKey ( field : DataModelField ) {
523- const relAttr = getAttribute ( field , '@relation' ) ;
524- if ( ! relAttr ) {
525- return false ;
526- }
527- return ! ! getAttributeArg ( relAttr , 'fields' ) ;
528- }
0 commit comments