Skip to content

Commit c68df70

Browse files
committed
fix(zod): "strip" mode causes create payload fields to be accidentally dropped
Fixes #1746
1 parent 1197c70 commit c68df70

File tree

6 files changed

+346
-32
lines changed

6 files changed

+346
-32
lines changed

packages/runtime/package.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@
7676
"./models": {
7777
"types": "./models.d.ts"
7878
},
79+
"./zod-utils": {
80+
"types": "./zod-utils.d.ts",
81+
"default": "./zod-utils.js"
82+
},
7983
"./package.json": {
8084
"default": "./package.json"
8185
}

packages/runtime/src/zod-utils.ts

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/* eslint-disable @typescript-eslint/no-explicit-any */
2+
import { z as Z } from 'zod';
3+
4+
/**
5+
* A smarter version of `z.union` that decide which candidate to use based on how few unrecognized keys it has.
6+
*
7+
* The helper is used to deal with ambiguity in union generated for Prisma inputs when the zod schemas are configured
8+
* to run in "strip" object parsing mode. Since "strip" automatically drops unrecognized keys, it may result in
9+
* accidentally matching a less-ideal schema candidate.
10+
*
11+
* The helper uses a custom schema to find the candidate that results in the fewest unrecognized keys when parsing the data.
12+
*/
13+
export function smartUnion(z: typeof Z, candidates: Z.ZodSchema[]) {
14+
// strip `z.lazy`
15+
const processedCandidates = candidates.map((candidate) => unwrapLazy(z, candidate));
16+
17+
if (processedCandidates.some((c) => !(c instanceof z.ZodObject || c instanceof z.ZodArray))) {
18+
// fall back to plain union if not all candidates are objects or arrays
19+
return z.union(candidates as any);
20+
}
21+
22+
let resultData: any;
23+
24+
return z
25+
.custom((data) => {
26+
if (Array.isArray(data)) {
27+
const { data: result, success } = smartArrayUnion(
28+
z,
29+
processedCandidates.filter((c) => c instanceof z.ZodArray),
30+
data
31+
);
32+
if (success) {
33+
resultData = result;
34+
}
35+
return success;
36+
} else {
37+
const { data: result, success } = smartObjectUnion(
38+
z,
39+
processedCandidates.filter((c) => c instanceof z.ZodObject),
40+
data
41+
);
42+
if (success) {
43+
resultData = result;
44+
}
45+
return success;
46+
}
47+
})
48+
.transform(() => {
49+
// return the parsed data
50+
return resultData;
51+
});
52+
}
53+
54+
function smartArrayUnion(z: typeof Z, candidates: Array<Z.ZodArray<Z.ZodObject<Z.ZodRawShape>>>, data: any) {
55+
if (candidates.length === 0) {
56+
return { data: undefined, success: false };
57+
}
58+
59+
if (!Array.isArray(data)) {
60+
return { data: undefined, success: false };
61+
}
62+
63+
if (data.length === 0) {
64+
return { data, success: true };
65+
}
66+
67+
// use the first element to identify the candidate schema to use
68+
const item = data[0];
69+
const itemSchema = identifyCandidate(
70+
z,
71+
candidates.map((candidate) => candidate.element),
72+
item
73+
);
74+
75+
// find the matching schema and re-parse the data
76+
const schema = candidates.find((candidate) => candidate.element === itemSchema);
77+
return schema!.safeParse(data);
78+
}
79+
80+
function smartObjectUnion(z: typeof Z, candidates: Z.ZodObject<Z.ZodRawShape>[], data: any) {
81+
if (candidates.length === 0) {
82+
return { data: undefined, success: false };
83+
}
84+
const schema = identifyCandidate(z, candidates, data);
85+
return schema.safeParse(data);
86+
}
87+
88+
function identifyCandidate(
89+
z: typeof Z,
90+
candidates: Array<Z.ZodObject<Z.ZodRawShape> | Z.ZodLazy<Z.ZodObject<Z.ZodRawShape>>>,
91+
data: any
92+
) {
93+
const strictResults = candidates.map((candidate) => {
94+
// make sure to strip `z.lazy` before parsing
95+
const unwrapped = unwrapLazy(z, candidate);
96+
return {
97+
schema: candidate,
98+
// force object schema to run in strict mode to capture unrecognized keys
99+
result: unwrapped.strict().safeParse(data),
100+
};
101+
});
102+
103+
// find the schema with the fewest unrecognized keys
104+
const { schema } = strictResults.sort((a, b) => {
105+
const aCount = countUnrecognizedKeys(a.result.error?.issues ?? []);
106+
const bCount = countUnrecognizedKeys(b.result.error?.issues ?? []);
107+
return aCount - bCount;
108+
})[0];
109+
return schema;
110+
}
111+
112+
function countUnrecognizedKeys(issues: Z.ZodIssue[]) {
113+
return issues
114+
.filter((issue) => issue.code === 'unrecognized_keys')
115+
.map((issue) => issue.keys.length)
116+
.reduce((a, b) => a + b, 0);
117+
}
118+
119+
function unwrapLazy<T extends Z.ZodSchema>(z: typeof Z, schema: T | Z.ZodLazy<T>): T {
120+
return schema instanceof z.ZodLazy ? schema.schema : schema;
121+
}

packages/schema/src/plugins/zod/generator.ts

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,13 @@ import { upperCaseFirst } from 'upper-case-first';
2323
import { name } from '.';
2424
import { getDefaultOutputFolder } from '../plugin-utils';
2525
import Transformer from './transformer';
26+
import { ObjectMode } from './types';
2627
import { makeFieldSchema, makeValidationRefinements } from './utils/schema-gen';
2728

2829
export class ZodSchemaGenerator {
2930
private readonly sourceFiles: SourceFile[] = [];
3031
private readonly globalOptions: PluginGlobalOptions;
32+
private readonly mode: ObjectMode;
3133

3234
constructor(
3335
private readonly model: Model,
@@ -39,6 +41,19 @@ export class ZodSchemaGenerator {
3941
throw new Error('Global options are required');
4042
}
4143
this.globalOptions = globalOptions;
44+
45+
// options validation
46+
if (
47+
this.options.mode &&
48+
(typeof this.options.mode !== 'string' || !['strip', 'strict', 'passthrough'].includes(this.options.mode))
49+
) {
50+
throw new PluginError(
51+
name,
52+
`Invalid mode option: "${this.options.mode}". Must be one of 'strip', 'strict', or 'passthrough'.`
53+
);
54+
}
55+
56+
this.mode = (this.options.mode ?? 'strict') as ObjectMode;
4257
}
4358

4459
async generate() {
@@ -55,17 +70,6 @@ export class ZodSchemaGenerator {
5570
ensureEmptyDir(output);
5671
Transformer.setOutputPath(output);
5772

58-
// options validation
59-
if (
60-
this.options.mode &&
61-
(typeof this.options.mode !== 'string' || !['strip', 'strict', 'passthrough'].includes(this.options.mode))
62-
) {
63-
throw new PluginError(
64-
name,
65-
`Invalid mode option: "${this.options.mode}". Must be one of 'strip', 'strict', or 'passthrough'.`
66-
);
67-
}
68-
6973
// calculate the models to be excluded
7074
const excludeModels = this.getExcludedModels();
7175

@@ -120,6 +124,7 @@ export class ZodSchemaGenerator {
120124
project: this.project,
121125
inputObjectTypes,
122126
zmodel: this.model,
127+
mode: this.mode,
123128
});
124129
await transformer.generateInputSchemas(this.options, this.model);
125130
this.sourceFiles.push(...transformer.sourceFiles);
@@ -215,6 +220,7 @@ export class ZodSchemaGenerator {
215220
project: this.project,
216221
inputObjectTypes: [],
217222
zmodel: this.model,
223+
mode: this.mode,
218224
});
219225
await transformer.generateEnumSchemas();
220226
this.sourceFiles.push(...transformer.sourceFiles);
@@ -243,6 +249,7 @@ export class ZodSchemaGenerator {
243249
project: this.project,
244250
inputObjectTypes,
245251
zmodel: this.model,
252+
mode: this.mode,
246253
});
247254
const moduleName = transformer.generateObjectSchema(generateUnchecked, this.options);
248255
moduleNames.push(moduleName);

packages/schema/src/plugins/zod/transformer.ts

Lines changed: 51 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import path from 'path';
77
import type { Project, SourceFile } from 'ts-morph';
88
import { upperCaseFirst } from 'upper-case-first';
99
import { computePrismaClientImport } from './generator';
10-
import { AggregateOperationSupport, TransformerParams } from './types';
10+
import { AggregateOperationSupport, ObjectMode, TransformerParams } from './types';
1111

1212
export default class Transformer {
1313
name: string;
@@ -28,6 +28,7 @@ export default class Transformer {
2828
private inputObjectTypes: PrismaDMMF.InputType[];
2929
public sourceFiles: SourceFile[] = [];
3030
private zmodel: Model;
31+
private mode: ObjectMode;
3132

3233
constructor(params: TransformerParams) {
3334
this.originalName = params.name ?? '';
@@ -40,6 +41,7 @@ export default class Transformer {
4041
this.project = params.project;
4142
this.inputObjectTypes = params.inputObjectTypes;
4243
this.zmodel = params.zmodel;
44+
this.mode = params.mode;
4345
}
4446

4547
static setOutputPath(outPath: string) {
@@ -73,7 +75,12 @@ export default class Transformer {
7375
}
7476

7577
generateImportZodStatement() {
76-
return "import { z } from 'zod';\n";
78+
let r = "import { z } from 'zod';\n";
79+
if (this.mode === 'strip') {
80+
// import the additional `smartUnion` helper
81+
r += `import { smartUnion } from '@zenstackhq/runtime/zod-utils';\n`;
82+
}
83+
return r;
7784
}
7885

7986
generateExportSchemaStatement(name: string, schema: string) {
@@ -210,8 +217,19 @@ export default class Transformer {
210217

211218
const opt = !field.isRequired ? '.optional()' : '';
212219

213-
let resString =
214-
alternatives.length === 1 ? alternatives.join(',\r\n') : `z.union([${alternatives.join(',\r\n')}])${opt}`;
220+
let resString: string;
221+
222+
if (alternatives.length === 1) {
223+
resString = alternatives.join(',\r\n');
224+
} else {
225+
if (alternatives.some((alt) => alt.includes('Unchecked'))) {
226+
// if the union is for combining checked and unchecked input types, use `smartUnion`
227+
// to parse with the best candidate at runtime
228+
resString = this.wrapWithSmartUnion(...alternatives) + `${opt}`;
229+
} else {
230+
resString = `z.union([${alternatives.join(',\r\n')}])${opt}`;
231+
}
232+
}
215233

216234
if (field.isNullable) {
217235
resString += '.nullable()';
@@ -391,17 +409,6 @@ export const ${this.name}ObjectSchema: SchemaType = ${schema} as SchemaType;`;
391409
return `${modelName}InputSchema.${queryName}`;
392410
}
393411

394-
wrapWithZodUnion(zodStringFields: string[]) {
395-
let wrapped = '';
396-
397-
wrapped += 'z.union([';
398-
wrapped += '\n';
399-
wrapped += ' ' + zodStringFields.join(',');
400-
wrapped += '\n';
401-
wrapped += '])';
402-
return wrapped;
403-
}
404-
405412
wrapWithZodObject(zodStringFields: string | string[], mode = 'strict') {
406413
let wrapped = '';
407414

@@ -425,6 +432,14 @@ export const ${this.name}ObjectSchema: SchemaType = ${schema} as SchemaType;`;
425432
return wrapped;
426433
}
427434

435+
wrapWithSmartUnion(...schemas: string[]) {
436+
if (this.mode === 'strip') {
437+
return `smartUnion(z, [${schemas.join(', ')}])`;
438+
} else {
439+
return `z.union([${schemas.join(', ')}])`;
440+
}
441+
}
442+
428443
async generateInputSchemas(options: PluginOptions, zmodel: Model) {
429444
const globalExports: string[] = [];
430445

@@ -464,7 +479,7 @@ export const ${this.name}ObjectSchema: SchemaType = ${schema} as SchemaType;`;
464479
this.resolveSelectIncludeImportAndZodSchemaLine(model);
465480

466481
let imports = [
467-
`import { z } from 'zod'`,
482+
this.generateImportZodStatement(),
468483
this.generateImportPrismaStatement(options),
469484
selectImport,
470485
includeImport,
@@ -523,7 +538,10 @@ export const ${this.name}ObjectSchema: SchemaType = ${schema} as SchemaType;`;
523538
);
524539
}
525540
const dataSchema = generateUnchecked
526-
? `z.union([${modelName}CreateInputObjectSchema, ${modelName}UncheckedCreateInputObjectSchema])`
541+
? this.wrapWithSmartUnion(
542+
`${modelName}CreateInputObjectSchema`,
543+
`${modelName}UncheckedCreateInputObjectSchema`
544+
)
527545
: `${modelName}CreateInputObjectSchema`;
528546
const fields = `${selectZodSchemaLineLazy} ${includeZodSchemaLineLazy} data: ${dataSchema}`;
529547
codeBody += `create: ${this.wrapWithZodObject(fields, mode)},`;
@@ -568,7 +586,10 @@ export const ${this.name}ObjectSchema: SchemaType = ${schema} as SchemaType;`;
568586
);
569587
}
570588
const dataSchema = generateUnchecked
571-
? `z.union([${modelName}UpdateInputObjectSchema, ${modelName}UncheckedUpdateInputObjectSchema])`
589+
? this.wrapWithSmartUnion(
590+
`${modelName}UpdateInputObjectSchema`,
591+
`${modelName}UncheckedUpdateInputObjectSchema`
592+
)
572593
: `${modelName}UpdateInputObjectSchema`;
573594
const fields = `${selectZodSchemaLineLazy} ${includeZodSchemaLineLazy} data: ${dataSchema}, where: ${modelName}WhereUniqueInputObjectSchema`;
574595
codeBody += `update: ${this.wrapWithZodObject(fields, mode)},`;
@@ -586,7 +607,10 @@ export const ${this.name}ObjectSchema: SchemaType = ${schema} as SchemaType;`;
586607
);
587608
}
588609
const dataSchema = generateUnchecked
589-
? `z.union([${modelName}UpdateManyMutationInputObjectSchema, ${modelName}UncheckedUpdateManyInputObjectSchema])`
610+
? this.wrapWithSmartUnion(
611+
`${modelName}UpdateManyMutationInputObjectSchema`,
612+
`${modelName}UncheckedUpdateManyInputObjectSchema`
613+
)
590614
: `${modelName}UpdateManyMutationInputObjectSchema`;
591615
const fields = `data: ${dataSchema}, where: ${modelName}WhereInputObjectSchema.optional()`;
592616
codeBody += `updateMany: ${this.wrapWithZodObject(fields, mode)},`;
@@ -606,10 +630,16 @@ export const ${this.name}ObjectSchema: SchemaType = ${schema} as SchemaType;`;
606630
);
607631
}
608632
const createSchema = generateUnchecked
609-
? `z.union([${modelName}CreateInputObjectSchema, ${modelName}UncheckedCreateInputObjectSchema])`
633+
? this.wrapWithSmartUnion(
634+
`${modelName}CreateInputObjectSchema`,
635+
`${modelName}UncheckedCreateInputObjectSchema`
636+
)
610637
: `${modelName}CreateInputObjectSchema`;
611638
const updateSchema = generateUnchecked
612-
? `z.union([${modelName}UpdateInputObjectSchema, ${modelName}UncheckedUpdateInputObjectSchema])`
639+
? this.wrapWithSmartUnion(
640+
`${modelName}UpdateInputObjectSchema`,
641+
`${modelName}UncheckedUpdateInputObjectSchema`
642+
)
613643
: `${modelName}UpdateInputObjectSchema`;
614644
const fields = `${selectZodSchemaLineLazy} ${includeZodSchemaLineLazy} where: ${modelName}WhereUniqueInputObjectSchema, create: ${createSchema}, update: ${updateSchema}`;
615645
codeBody += `upsert: ${this.wrapWithZodObject(fields, mode)},`;

packages/schema/src/plugins/zod/types.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ export type TransformerParams = {
1414
project: Project;
1515
inputObjectTypes: PrismaDMMF.InputType[];
1616
zmodel: Model;
17+
mode: ObjectMode;
1718
};
1819

1920
export type AggregateOperationSupport = {
@@ -25,3 +26,5 @@ export type AggregateOperationSupport = {
2526
avg?: boolean;
2627
};
2728
};
29+
30+
export type ObjectMode = 'strict' | 'strip' | 'passthrough';

0 commit comments

Comments
 (0)