Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions packages/plugins/tanstack-query/src/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
ensureEmptyDir,
generateModelMeta,
getDataModels,
getPrismaClientGenerator,
isDelegateModel,
requireOption,
resolvePath,
Expand Down Expand Up @@ -52,7 +53,6 @@ export async function generate(model: Model, options: PluginOptions, dmmf: DMMF.
`Invalid value for "portable" option: ${options.portable}, a boolean value is expected`
);
}
const portable = options.portable ?? false;

await generateModelMeta(project, models, typeDefs, {
output: path.join(outDir, '__model_meta.ts'),
Expand All @@ -70,8 +70,13 @@ export async function generate(model: Model, options: PluginOptions, dmmf: DMMF.
generateModelHooks(target, version, project, outDir, dataModel, mapping, options);
});

if (portable) {
generateBundledTypes(project, outDir, options);
if (options.portable) {
const gen = getPrismaClientGenerator(model);
if (gen?.isNewGenerator) {
warnings.push(`The "portable" option is not supported with the "prisma-client" generator and is ignored.`);
} else {
generateBundledTypes(project, outDir, options);
}
}

await saveProject(project);
Expand Down
14 changes: 13 additions & 1 deletion packages/schema/src/cli/actions/generate.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { PluginError } from '@zenstackhq/sdk';
import { getPrismaClientGenerator, PluginError } from '@zenstackhq/sdk';
import { isPlugin } from '@zenstackhq/sdk/ast';
import colors from 'colors';
import path from 'path';
Expand Down Expand Up @@ -70,6 +70,18 @@ async function runPlugins(options: Options) {

const model = await loadDocument(schema);

const gen = getPrismaClientGenerator(model);
if (gen?.isNewGenerator && !options.output) {
console.error(
colors.red(
'When using the "prisma-client" generator, you must provide an explicit output path with the "--output" CLI parameter.'
)
);
throw new CliError(
'When using with the "prisma-client" generator, you must provide an explicit output path with the "--output" CLI parameter.'
);
}

for (const name of [...(options.withPlugins ?? []), ...(options.withoutPlugins ?? [])]) {
const pluginDecl = model.declarations.find((d) => isPlugin(d) && d.name === name);
if (!pluginDecl) {
Expand Down
7 changes: 2 additions & 5 deletions packages/schema/src/cli/plugin-runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,7 @@ export class PluginRunner {
const otherPlugins = plugins.filter((p) => !p.options.preprocessor);

// calculate all plugins (including core plugins implicitly enabled)
const { corePlugins, userPlugins } = this.calculateAllPlugins(
runnerOptions,
otherPlugins,
);
const { corePlugins, userPlugins } = this.calculateAllPlugins(runnerOptions, otherPlugins);
const allPlugins = [...corePlugins, ...userPlugins];

// check dependencies
Expand Down Expand Up @@ -448,7 +445,7 @@ export class PluginRunner {
}

async function compileProject(project: Project, runnerOptions: PluginRunnerOptions) {
if (runnerOptions.compile !== false) {
if (!runnerOptions.output && runnerOptions.compile !== false) {
// emit
await emitProject(project);
} else {
Expand Down
72 changes: 55 additions & 17 deletions packages/schema/src/plugins/enhancer/enhance/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {
getDataModelAndTypeDefs,
getDataModels,
getForeignKeyFields,
getLiteral,
getPrismaClientGenerator,
getRelationField,
hasAttribute,
isDelegateModel,
Expand All @@ -22,7 +22,6 @@ import {
ReferenceExpr,
isArrayExpr,
isDataModel,
isGeneratorDecl,
isTypeDef,
type Model,
} from '@zenstackhq/sdk/ast';
Expand Down Expand Up @@ -56,7 +55,7 @@ import { generateTypeDefType } from './model-typedef-generator';
// information of delegate models and their sub models
type DelegateInfo = [DataModel, DataModel[]][];

const LOGICAL_CLIENT_GENERATION_PATH = './.logical-prisma-client';
const LOGICAL_CLIENT_GENERATION_PATH = './logical-prisma-client';

export class EnhancerGenerator {
// regex for matching "ModelCreateXXXInput" and "ModelUncheckedCreateXXXInput" type
Expand Down Expand Up @@ -114,6 +113,9 @@ export class EnhancerGenerator {
if (this.needsLogicalClient) {
prismaTypesFixed = true;
resultPrismaTypeImport = LOGICAL_CLIENT_GENERATION_PATH;
if (this.isNewPrismaClientGenerator) {
resultPrismaTypeImport += '/client';
}
const result = await this.generateLogicalPrisma();
dmmf = result.dmmf;
}
Expand Down Expand Up @@ -440,23 +442,14 @@ export type Enhanced<Client> =
}

private getPrismaClientGeneratorName(model: Model) {
for (const generator of model.declarations.filter(isGeneratorDecl)) {
if (
generator.fields.some(
(f) => f.name === 'provider' && getLiteral<string>(f.value) === 'prisma-client-js'
)
) {
return generator.name;
}
const gen = getPrismaClientGenerator(model);
if (!gen) {
throw new PluginError(name, `Cannot find "prisma-client-js" or "prisma-client" generator in the schema`);
}
throw new PluginError(name, `Cannot find prisma-client-js generator in the schema`);
return gen.name;
}

private async processClientTypes(prismaClientDir: string) {
// make necessary updates to the generated `index.d.ts` file and overwrite it
const project = new Project();
const sf = project.addSourceFileAtPath(path.join(prismaClientDir, 'index.d.ts'));

// build a map of delegate models and their sub models
const delegateInfo: DelegateInfo = [];
this.model.declarations
Expand All @@ -468,6 +461,16 @@ export type Enhanced<Client> =
}
});

if (this.isNewPrismaClientGenerator) {
await this.processClientTypesNewPrismaGenerator(prismaClientDir, delegateInfo);
} else {
await this.processClientTypesLegacyPrismaGenerator(prismaClientDir, delegateInfo);
}
}
private async processClientTypesLegacyPrismaGenerator(prismaClientDir: string, delegateInfo: DelegateInfo) {
const project = new Project();
const sf = project.addSourceFileAtPath(path.join(prismaClientDir, 'index.d.ts'));

// transform index.d.ts and write it into a new file (better perf than in-line editing)
const sfNew = project.createSourceFile(path.join(prismaClientDir, 'index-fixed.d.ts'), undefined, {
overwrite: true,
Expand All @@ -484,6 +487,36 @@ export type Enhanced<Client> =
await sfNew.save();
}

private async processClientTypesNewPrismaGenerator(prismaClientDir: string, delegateInfo: DelegateInfo) {
const project = new Project();

for (const d of this.model.declarations.filter(isDataModel)) {
const fileName = `${prismaClientDir}/models/${d.name}.ts`;
const sf = project.addSourceFileAtPath(fileName);
const sfNew = project.createSourceFile(`${prismaClientDir}/models/${d.name}-fixed.ts`, undefined, {
overwrite: true,
});

const syntaxList = sf.getChildren()[0];
if (!Node.isSyntaxList(syntaxList)) {
throw new PluginError(name, `Unexpected syntax list structure in ${fileName}`);
}

syntaxList.getChildren().forEach((node) => {
if (Node.isInterfaceDeclaration(node)) {
sfNew.addInterface(this.transformInterface(node, delegateInfo));
} else if (Node.isTypeAliasDeclaration(node)) {
sfNew.addTypeAlias(this.transformTypeAlias(node, delegateInfo));
} else {
sfNew.addStatements(node.getText());
}
});

await sfNew.move(sf.getFilePath(), { overwrite: true });
await sfNew.save();
}
}

private transformPrismaTypes(sf: SourceFile, sfNew: SourceFile, delegateInfo: DelegateInfo) {
// copy toplevel imports
sfNew.addImportDeclarations(sf.getImportDeclarations().map((n) => n.getStructure()));
Expand Down Expand Up @@ -639,7 +672,7 @@ export type Enhanced<Client> =
source = `${payloadRecord[1]
.map(
(concrete) =>
`($${concrete.name}Payload<ExtArgs> & { scalars: { ${discriminatorDecl.name}: '${concrete.name}' } })`
`(Prisma.$${concrete.name}Payload<ExtArgs> & { scalars: { ${discriminatorDecl.name}: '${concrete.name}' } })`
)
.join(' | ')}`;
}
Expand Down Expand Up @@ -916,4 +949,9 @@ export type Enhanced<Client> =
private trimEmptyLines(source: string): string {
return source.replace(/^\s*[\r\n]/gm, '');
}

private get isNewPrismaClientGenerator() {
const gen = getPrismaClientGenerator(this.model);
return !!gen?.isNewGenerator;
}
}
20 changes: 10 additions & 10 deletions packages/schema/src/plugins/prisma/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@ import {
PluginError,
type PluginFunction,
type PluginOptions,
getLiteral,
getPrismaClientGenerator,
normalizedRelative,
resolvePath,
} from '@zenstackhq/sdk';
import { GeneratorDecl, isGeneratorDecl } from '@zenstackhq/sdk/ast';
import { getDMMF } from '@zenstackhq/sdk/prisma';
import colors from 'colors';
import fs from 'fs';
Expand Down Expand Up @@ -58,13 +57,9 @@ const run: PluginFunction = async (model, options, _dmmf, _globalOptions) => {
}

// extract user-provided prisma client output path
const generator = model.declarations.find(
(d): d is GeneratorDecl =>
isGeneratorDecl(d) &&
d.fields.some((f) => f.name === 'provider' && getLiteral(f.value) === 'prisma-client-js')
);
const clientOutputField = generator?.fields.find((f) => f.name === 'output');
const clientOutput = getLiteral<string>(clientOutputField?.value);
const gen = getPrismaClientGenerator(model);
const clientOutput = gen?.output;
const newGenerator = !!gen?.isNewGenerator;

if (clientOutput) {
if (path.isAbsolute(clientOutput)) {
Expand All @@ -81,6 +76,11 @@ const run: PluginFunction = async (model, options, _dmmf, _globalOptions) => {
clientOutputDir = prismaClientPath;
}

if (newGenerator) {
// "prisma-client" generator requires an extra "/client" import suffix
prismaClientPath = `${prismaClientPath}/client`;
}

// get PrismaClient dts path

if (clientOutput) {
Expand All @@ -89,7 +89,7 @@ const run: PluginFunction = async (model, options, _dmmf, _globalOptions) => {
prismaClientDtsPath = path.resolve(path.dirname(options.schemaPath), clientOutputDir, 'index.d.ts');
}

if (!prismaClientDtsPath || !fs.existsSync(prismaClientDtsPath)) {
if (!newGenerator && (!prismaClientDtsPath || !fs.existsSync(prismaClientDtsPath))) {
// if the file does not exist, try node module resolution
try {
// the resolution is relative to the schema path by default
Expand Down
5 changes: 4 additions & 1 deletion packages/schema/src/plugins/prisma/schema-generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,10 @@ export class PrismaSchemaGenerator {

// deal with configuring PrismaClient preview features
const provider = generator.fields.find((f) => f.name === 'provider');
if (provider?.text === JSON.stringify('prisma-client-js')) {
if (
provider?.text === JSON.stringify('prisma-client-js') ||
provider?.text === JSON.stringify('prisma-client')
) {
const prismaVersion = getPrismaVersion();
if (prismaVersion) {
const previewFeatures = JSON.parse(
Expand Down
31 changes: 30 additions & 1 deletion packages/sdk/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,11 @@ export function getPreviewFeatures(model: Model) {
const jsGenerator = model.declarations.find(
(d) =>
isGeneratorDecl(d) &&
d.fields.some((f) => f.name === 'provider' && getLiteral<string>(f.value) === 'prisma-client-js')
d.fields.some(
(f) =>
(f.name === 'provider' && getLiteral<string>(f.value) === 'prisma-client-js') ||
getLiteral<string>(f.value) === 'prisma-client'
)
) as GeneratorDecl | undefined;

if (jsGenerator) {
Expand Down Expand Up @@ -683,3 +687,28 @@ export function getRelationName(field: DataModelField) {
}
return getAttributeArgLiteral(relAttr, 'name');
}

export function getPrismaClientGenerator(model: Model) {
const decl = model.declarations.find(
(d): d is GeneratorDecl =>
isGeneratorDecl(d) &&
d.fields.some(
(f) =>
f.name === 'provider' &&
(getLiteral<string>(f.value) === 'prisma-client-js' ||
getLiteral<string>(f.value) === 'prisma-client')
)
);
if (!decl) {
return undefined;
}

const provider = getLiteral<string>(decl.fields.find((f) => f.name === 'provider')?.value);
return {
name: decl.name,
output: getLiteral<string>(decl.fields.find((f) => f.name === 'output')?.value),
previewFeatures: getLiteralArray<string>(decl.fields.find((f) => f.name === 'previewFeatures')?.value),
provider,
isNewGenerator: provider === 'prisma-client',
};
}
42 changes: 21 additions & 21 deletions packages/testtools/src/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -278,27 +278,6 @@ export async function loadSchema(schema: string, options?: SchemaLoadOptions) {
fs.cpSync(dep, path.join(projectDir, 'node_modules', pkgJson.name), { recursive: true, force: true });
});

const prismaLoadPath = options?.prismaLoadPath
? path.isAbsolute(options.prismaLoadPath)
? options.prismaLoadPath
: path.join(projectDir, options.prismaLoadPath)
: path.join(projectDir, 'node_modules/.prisma/client');
const prismaModule = require(prismaLoadPath);
const PrismaClient = prismaModule.PrismaClient;

let clientOptions: object = { log: ['info', 'warn', 'error'] };
if (options?.prismaClientOptions) {
clientOptions = { ...clientOptions, ...options.prismaClientOptions };
}
let prisma = new PrismaClient(clientOptions);
// https://github.com/prisma/prisma/issues/18292
prisma[Symbol.for('nodejs.util.inspect.custom')] = 'PrismaClient';

if (opt.pulseApiKey) {
const withPulse = loadModule('@prisma/extension-pulse/node', projectDir).withPulse;
prisma = prisma.$extends(withPulse({ apiKey: opt.pulseApiKey }));
}

opt.extraSourceFiles?.forEach(({ name, content }) => {
fs.writeFileSync(path.join(projectDir, name), content);
});
Expand All @@ -325,6 +304,27 @@ export async function loadSchema(schema: string, options?: SchemaLoadOptions) {
run('npx tsc --project tsconfig.json');
}

const prismaLoadPath = options?.prismaLoadPath
? path.isAbsolute(options.prismaLoadPath)
? options.prismaLoadPath
: path.join(projectDir, options.prismaLoadPath)
: path.join(projectDir, 'node_modules/.prisma/client');
const prismaModule = require(prismaLoadPath);
const PrismaClient = prismaModule.PrismaClient;

let clientOptions: object = { log: ['info', 'warn', 'error'] };
if (options?.prismaClientOptions) {
clientOptions = { ...clientOptions, ...options.prismaClientOptions };
}
let prisma = new PrismaClient(clientOptions);
// https://github.com/prisma/prisma/issues/18292
prisma[Symbol.for('nodejs.util.inspect.custom')] = 'PrismaClient';

if (opt.pulseApiKey) {
const withPulse = loadModule('@prisma/extension-pulse/node', projectDir).withPulse;
prisma = prisma.$extends(withPulse({ apiKey: opt.pulseApiKey }));
}

if (options?.getPrismaOnly) {
return {
prisma,
Expand Down
Loading
Loading