Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 2a4b5cc

Browse files
authored
fix: add enum import to zod generation (zenstackhq#528)
1 parent 3aa0f51 commit 2a4b5cc

File tree

2 files changed

+26
-8
lines changed

2 files changed

+26
-8
lines changed

packages/schema/src/plugins/zod/generator.ts

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,28 @@ import {
77
emitProject,
88
getDataModels,
99
getLiteral,
10+
getPrismaClientImportSpec,
1011
hasAttribute,
12+
isEnumFieldReference,
1113
isForeignKeyField,
1214
resolvePath,
1315
saveProject,
1416
} from '@zenstackhq/sdk';
15-
import { DataModel, DataSource, Model, isDataModel, isDataSource, isEnum } from '@zenstackhq/sdk/ast';
17+
import { DataModel, DataSource, EnumField, Model, isDataModel, isDataSource, isEnum } from '@zenstackhq/sdk/ast';
1618
import {
1719
AggregateOperationSupport,
1820
addMissingInputObjectTypes,
1921
resolveAggregateOperationSupport,
2022
} from '@zenstackhq/sdk/dmmf-helpers';
2123
import { promises as fs } from 'fs';
24+
import { streamAllContents } from 'langium';
2225
import path from 'path';
2326
import { Project } from 'ts-morph';
27+
import { upperCaseFirst } from 'upper-case-first';
28+
import { isFromStdlib } from '../../language-server/utils';
2429
import { getDefaultOutputFolder } from '../plugin-utils';
2530
import Transformer from './transformer';
2631
import removeDir from './utils/removeDir';
27-
import { upperCaseFirst } from 'upper-case-first';
2832
import { makeFieldSchema, makeValidationRefinements } from './utils/schema-gen';
2933

3034
export async function generate(model: Model, options: PluginOptions, dmmf: DMMF.Document) {
@@ -176,8 +180,6 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
176180
overwrite: true,
177181
});
178182
sf.replaceWithText((writer) => {
179-
writer.writeLine('/* eslint-disable */');
180-
181183
const fields = model.fields.filter(
182184
(field) =>
183185
!AUXILIARY_FIELDS.includes(field.name) &&
@@ -186,9 +188,25 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
186188
!isForeignKeyField(field)
187189
);
188190

191+
writer.writeLine('/* eslint-disable */');
189192
writer.writeLine(`import { z } from 'zod';`);
190193

191-
// import enums
194+
// import user-defined enums from Prisma as they might be referenced in the expressions
195+
const importEnums = new Set<string>();
196+
for (const node of streamAllContents(model)) {
197+
if (isEnumFieldReference(node)) {
198+
const field = node.target.ref as EnumField;
199+
if (!isFromStdlib(field.$container)) {
200+
importEnums.add(field.$container.name);
201+
}
202+
}
203+
}
204+
if (importEnums.size > 0) {
205+
const prismaImport = getPrismaClientImportSpec(model.$container, path.join(output, 'models'));
206+
writer.writeLine(`import { ${[...importEnums].join(', ')} } from '${prismaImport}';`);
207+
}
208+
209+
// import enum schemas
192210
for (const field of fields) {
193211
if (field.type.reference?.ref && isEnum(field.type.reference?.ref)) {
194212
const name = upperCaseFirst(field.type.reference?.ref.name);
@@ -205,9 +223,9 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
205223
});
206224
writer.writeLine(');');
207225

226+
// compile "@@validate" to ".refine"
208227
const refinements = makeValidationRefinements(model);
209228
if (refinements.length > 0) {
210-
console.log('Generated refinements:', refinements);
211229
writer.writeLine(`function refine(schema: z.ZodType) { return schema${refinements.join('\n')}; }`);
212230
}
213231

packages/sdk/src/utils.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,8 @@ export function getAttributeArgLiteral<T extends string | number | boolean>(
133133
return undefined;
134134
}
135135

136-
export function isEnumFieldReference(expr: Expression): expr is ReferenceExpr {
137-
return isReferenceExpr(expr) && isEnumField(expr.target.ref);
136+
export function isEnumFieldReference(node: AstNode): node is ReferenceExpr {
137+
return isReferenceExpr(node) && isEnumField(node.target.ref);
138138
}
139139

140140
/**

0 commit comments

Comments
 (0)