
feat: Match up vertex and fragment locations in render pipeline #1377

Open · wants to merge 21 commits into base: main
4 changes: 2 additions & 2 deletions apps/typegpu-docs/src/content/docs/fundamentals/tgsl.mdx
@@ -123,12 +123,12 @@ const root = await tgpu.init();
const backgroundColorUniform = root['~unstable'].createUniform(d.vec4f, d.vec4f(0.114, 0.447, 0.941, 1));

const fragmentTgsl = tgpu['~unstable'].fragmentFn({
out: d.location(0, d.vec4f),
out: d.vec4f,
})(() => backgroundColorUniform.value);
// ^?

const fragmentWgsl = tgpu['~unstable'].fragmentFn({
out: d.location(0, d.vec4f),
out: d.vec4f,
})`{
return backgroundColorUniform;
}`.$uses({ backgroundColorUniform });
@@ -311,7 +311,7 @@ const vertex = tgpu['~unstable'].vertexFn({

const fragment = tgpu['~unstable'].fragmentFn({
in: { cell: d.f32 },
out: d.location(0, d.vec4f),
out: d.vec4f,
})((input) => {
if (input.cell === -1) {
return d.vec4f(0.5, 0.5, 0.5, 1);
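
The documentation change above reflects the default that the rest of this PR leans on: a fragment output declared as a bare `d.vec4f` resolves to `@location(0) vec4f` in the emitted WGSL, so the docs drop the explicit `d.location(0, ...)` wrapper. A minimal sketch of the updated usage, assuming the standard `typegpu` and `typegpu/data` import paths:

```ts
import tgpu from 'typegpu';
import * as d from 'typegpu/data';

// A bare vec4f output resolves to `@location(0) vec4f`,
// so no `d.location(0, ...)` wrapper is needed.
const fullScreenRed = tgpu['~unstable'].fragmentFn({
  out: d.vec4f,
})(() => d.vec4f(1, 0, 0, 1));

// An explicit location is still honored when one is given.
const explicitLocation = tgpu['~unstable'].fragmentFn({
  out: d.location(0, d.vec4f),
})(() => d.vec4f(0, 1, 0, 1));
```
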
108 changes: 51 additions & 57 deletions packages/typegpu/src/core/function/fnCore.ts
@@ -21,20 +21,19 @@ import {
import { extractArgs } from './extractArgs.ts';
import type { Implementation } from './fnTypes.ts';

export interface TgpuFnShellBase<Args extends unknown[], Return> {
readonly argTypes: Args;
readonly returnType: Return;
readonly isEntry: boolean;
}

export interface FnCore {
applyExternals(newExternals: ExternalMap): void;
resolve(ctx: ResolutionCtx, fnAttribute?: string): string;
resolve(
ctx: ResolutionCtx,
argTypes: unknown[],
returnType: unknown,
fnAttribute?: string,
): string;
}

export function createFnCore(
shell: TgpuFnShellBase<unknown[], unknown>,
implementation: Implementation,
isEntry: boolean,
): FnCore {
/**
* External application has to be deferred until resolution because
@@ -44,37 +43,34 @@ export function createFnCore(
*/
const externalsToApply: ExternalMap[] = [];

if (typeof implementation === 'string') {
if (!shell.isEntry) {
addArgTypesToExternals(
implementation,
shell.argTypes,
(externals) => externalsToApply.push(externals),
);
addReturnTypeToExternals(
implementation,
shell.returnType,
(externals) => externalsToApply.push(externals),
);
} else {
if (isWgslStruct(shell.argTypes[0])) {
externalsToApply.push({ In: shell.argTypes[0] });
}

if (isWgslStruct(shell.returnType)) {
externalsToApply.push({ Out: shell.returnType });
}
}
}

const core = {
applyExternals(newExternals: ExternalMap): void {
externalsToApply.push(newExternals);
},

resolve(ctx: ResolutionCtx, fnAttribute = ''): string {
resolve(
ctx: ResolutionCtx,
argTypes: unknown[],
returnType: unknown,
fnAttribute = '',
): string {
const externalMap: ExternalMap = {};

if (typeof implementation === 'string') {
if (!isEntry) {
addArgTypesToExternals(
implementation,
argTypes,
(externals) => externalsToApply.push(externals),
);
addReturnTypeToExternals(
implementation,
returnType,
(externals) => externalsToApply.push(externals),
);
}
}

for (const externals of externalsToApply) {
applyExternals(externalMap, externals);
}
@@ -91,19 +87,19 @@ export function createFnCore(
let header = '';
let body = '';

if (shell.isEntry) {
const input = isWgslStruct(shell.argTypes[0])
? `(in: ${ctx.resolve(shell.argTypes[0])})`
if (isEntry) {
const input = isWgslStruct(argTypes[0])
? `(in: ${ctx.resolve(argTypes[0])})`
: '()';

const attributes = isWgslData(shell.returnType)
? getAttributesString(shell.returnType)
const attributes = isWgslData(returnType)
? getAttributesString(returnType)
: '';
const output = shell.returnType !== Void
? isWgslStruct(shell.returnType)
? `-> ${ctx.resolve(shell.returnType)}`
const output = returnType !== Void
? isWgslStruct(returnType)
? `-> ${ctx.resolve(returnType)}`
: `-> ${attributes !== '' ? attributes : '@location(0)'} ${
ctx.resolve(shell.returnType)
ctx.resolve(returnType)
}`
: '';

@@ -112,9 +108,9 @@ export function createFnCore(
} else {
const providedArgs = extractArgs(replacedImpl);

if (providedArgs.args.length !== shell.argTypes.length) {
if (providedArgs.args.length !== argTypes.length) {
throw new Error(
`WGSL implementation has ${providedArgs.args.length} arguments, while the shell has ${shell.argTypes.length} arguments.`,
`WGSL implementation has ${providedArgs.args.length} arguments, while the shell has ${argTypes.length} arguments.`,
);
}

@@ -124,21 +120,19 @@ export function createFnCore(
ctx,
`parameter ${argInfo.identifier}`,
argInfo.type,
shell.argTypes[i],
argTypes[i],
)
}`
).join(', ');

const output = shell.returnType === Void
? ''
: `-> ${
checkAndReturnType(
ctx,
'return type',
providedArgs.ret?.type,
shell.returnType,
)
}`;
const output = returnType === Void ? '' : `-> ${
checkAndReturnType(
ctx,
'return type',
providedArgs.ret?.type,
returnType,
)
}`;

header = `(${input}) ${output}`;

@@ -171,7 +165,7 @@ export function createFnCore(

// generate wgsl string
const { head, body } = ctx.fnToWgsl({
args: shell.argTypes.map((arg, i) =>
args: argTypes.map((arg, i) =>
snip(
ast.params[i]?.type === FuncParameterType.identifier
? ast.params[i].name
@@ -186,14 +180,14 @@ export function createFnCore(
alias,
snip(
`_arg_${i}.${name}`,
(shell.argTypes[i] as AnyWgslStruct)
(argTypes[i] as AnyWgslStruct)
.propTypes[name] as AnyWgslData,
),
])
: []
),
),
returnType: shell.returnType as AnyWgslData,
returnType: returnType as AnyWgslData,
body: ast.body,
externalMap,
});
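
The shell is no longer captured by `createFnCore`; callers now hand `argTypes` and `returnType` to `resolve` at resolution time, which is what lets the fragment shell swap in a location-annotated input schema later on. A simplified, standalone sketch of the entry-point header logic above (not the library code itself; `ResolutionCtx` and attribute extraction are replaced with plain strings for illustration):

```ts
// Sketch of how the entry-point header is assembled from the argTypes/returnType
// now passed into `resolve`. Real resolution goes through ResolutionCtx; here
// plain strings stand in for resolved WGSL types.
type HeaderSketch = {
  inputStruct?: string;    // resolved name of the In struct, if any
  returnType?: string;     // resolved WGSL return type, e.g. 'vec4f'; undefined for Void
  returnIsStruct?: boolean;
  attributes?: string;     // e.g. '@location(2) ' when the schema carries one
};

function entryHeader({
  inputStruct,
  returnType,
  returnIsStruct,
  attributes = '',
}: HeaderSketch): string {
  const input = inputStruct ? `(in: ${inputStruct})` : '()';
  const output = returnType === undefined
    ? ''                                   // Void: no return clause
    : returnIsStruct
    ? `-> ${returnType}`                   // struct outputs carry their own locations
    : `-> ${attributes !== '' ? attributes : '@location(0)'} ${returnType}`;
  return `${input} ${output}`;
}

// A bare vec4f fragment output falls back to @location(0):
console.log(entryHeader({ returnType: 'vec4f' })); // "() -> @location(0) vec4f"
```
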
23 changes: 12 additions & 11 deletions packages/typegpu/src/core/function/ioOutputType.ts
@@ -29,12 +29,13 @@ export type IOLayoutToSchema<T extends IOLayout> = T extends BaseData
: never;

export function withLocations<T extends IOData>(
members: IORecord<T>,
members: IORecord<T> | undefined,
locations: Record<string, number> = {},
): WithLocations<IORecord<T>> {
let nextLocation = 0;

return Object.fromEntries(
Object.entries(members).map(([key, member]) => {
Object.entries(members ?? {}).map(([key, member]) => {
if (isBuiltin(member)) {
// Skipping builtins
return [key, member];
@@ -47,22 +48,22 @@ export function withLocations<T extends IOData>(
return [key, member];
}

return [key, location(nextLocation++, member)];
return [key, location(locations[key] ?? nextLocation++, member)];
}),
);
}

export function createIoSchema<
T extends IOData,
Layout extends IORecord<T> | IOLayout<T>,
>(returnType: Layout) {
>(layout: Layout, locations: Record<string, number> = {}) {
return (
isData(returnType)
? isVoid(returnType)
? returnType
: getCustomLocation(returnType) !== undefined
? returnType
: location(0, returnType)
: struct(withLocations(returnType) as Record<string, T>)
isData(layout)
? isVoid(layout)
? layout
: getCustomLocation(layout) !== undefined
? layout
: location(0, layout)
: struct(withLocations(layout, locations) as Record<string, T>)
) as IOLayoutToSchema<Layout>;
}
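
`withLocations` now takes an optional `locations` map, so a caller (the fragment shell, via `ctx.varyingLocations`, presumably populated from the paired vertex stage's outputs) can pin specific members to fixed indices while everything else keeps receiving sequential ones. A standalone sketch of just the assignment rule, with builtin handling and the real `IOData` types omitted:

```ts
// Mirrors the `locations[key] ?? nextLocation++` rule from withLocations above:
// an explicit entry in `locations` wins, otherwise the next sequential index is used.
function assignLocations(
  members: Record<string, string>,          // member name -> WGSL type (stand-in for IOData)
  locations: Record<string, number> = {},
): Record<string, { location: number; type: string }> {
  let nextLocation = 0;
  const result: Record<string, { location: number; type: string }> = {};
  for (const [key, type] of Object.entries(members)) {
    result[key] = { location: locations[key] ?? nextLocation++, type };
  }
  return result;
}

// `cell` is pinned to the index the vertex stage used; `uv` still gets the next
// free sequential index.
console.log(assignLocations({ cell: 'f32', uv: 'vec2f' }, { cell: 3 }));
// -> { cell: { location: 3, type: 'f32' }, uv: { location: 0, type: 'vec2f' } }
```
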
4 changes: 3 additions & 1 deletion packages/typegpu/src/core/function/tgpuComputeFn.ts
@@ -151,7 +151,7 @@ function createComputeFn<ComputeIn extends IORecord<AnyComputeBuiltin>>(
[$getNameForward]: FnCore;
};

const core = createFnCore(shell, implementation);
const core = createFnCore(implementation, true);
const inputType = shell.argTypes[0];

const result: This = {
@@ -174,6 +174,8 @@ function createComputeFn<ComputeIn extends IORecord<AnyComputeBuiltin>>(
'~resolve'(ctx: ResolutionCtx): string {
return core.resolve(
ctx,
shell.argTypes,
shell.returnType,
`@compute @workgroup_size(${workgroupSize.join(', ')}) `,
);
},
8 changes: 4 additions & 4 deletions packages/typegpu/src/core/function/tgpuFn.ts
@@ -1,8 +1,8 @@
import { type AnyData, snip, UnknownData } from '../../data/dataTypes.ts';
import { Void } from '../../data/wgslTypes.ts';
import { createDualImpl } from '../../shared/generators.ts';
import type { TgpuNamable } from '../../shared/meta.ts';
import { getName, setName } from '../../shared/meta.ts';
import { createDualImpl } from '../../shared/generators.ts';
import type { Infer } from '../../shared/repr.ts';
import {
$getNameForward,
@@ -171,7 +171,7 @@ function createFn<ImplSchema extends AnyFn>(
[$getNameForward]: FnCore;
};

const core = createFnCore(shell, implementation as Implementation);
const core = createFnCore(implementation as Implementation, false);

const fnBase: This = {
[$internal]: {
@@ -203,7 +203,7 @@ function createFn<ImplSchema extends AnyFn>(

'~resolve'(ctx: ResolutionCtx): string {
if (typeof implementation === 'string') {
return core.resolve(ctx);
return core.resolve(ctx, shell.argTypes, shell.returnType);
}

const generationCtx = ctx as GenerationCtx;
@@ -215,7 +215,7 @@ function createFn<ImplSchema extends AnyFn>(

try {
generationCtx.callStack.push(shell.returnType);
return core.resolve(ctx);
return core.resolve(ctx, shell.argTypes, shell.returnType);
} finally {
generationCtx.callStack.pop();
}
40 changes: 27 additions & 13 deletions packages/typegpu/src/core/function/tgpuFragmentFn.ts
@@ -53,8 +53,8 @@ type TgpuFragmentFnShellHeader<
FragmentIn extends FragmentInConstrained,
FragmentOut extends FragmentOutConstrained,
> = {
readonly argTypes: [IOLayoutToSchema<FragmentIn>] | [];
readonly targets: FragmentOut;
readonly in: FragmentIn | undefined;
readonly out: FragmentOut;
readonly returnType: IOLayoutToSchema<FragmentOut>;
readonly isEntry: true;
};
@@ -150,10 +150,8 @@ export function fragmentFn<
out: FragmentOut;
}): TgpuFragmentFnShell<FragmentIn, FragmentOut> {
const shell: TgpuFragmentFnShellHeader<FragmentIn, FragmentOut> = {
argTypes: options.in && Object.keys(options.in).length !== 0
? [createIoSchema(options.in)]
: [],
targets: options.out,
in: options.in,
out: options.out,
returnType: createIoSchema(options.out),
isEntry: true,
};
@@ -181,9 +179,8 @@ function createFragmentFn(
): TgpuFragmentFn {
type This = TgpuFragmentFn & SelfResolvable & { [$getNameForward]: FnCore };

const core = createFnCore(shell, implementation);
const core = createFnCore(implementation, true);
const outputType = shell.returnType;
const inputType = shell.argTypes[0];
if (typeof implementation === 'string') {
addReturnTypeToExternals(
implementation,
@@ -207,15 +204,27 @@ function createFragmentFn(
if (isNamable(outputType)) {
outputType.$name(`${newLabel}_Output`);
}
if (isNamable(inputType)) {
inputType.$name(`${newLabel}_Input`);
}
return this;
},

'~resolve'(ctx: ResolutionCtx): string {
const inputWithLocation = shell.in
? createIoSchema(shell.in, ctx.varyingLocations)
.$name(`${getName(this) ?? ''}_Input`)
: undefined;

if (inputWithLocation) {
core.applyExternals({ In: inputWithLocation });
}
core.applyExternals({ Out: outputType });

if (typeof implementation === 'string') {
return core.resolve(ctx, '@fragment ');
return core.resolve(
ctx,
inputWithLocation ? [inputWithLocation] : [],
shell.returnType,
'@fragment ',
);
}

const generationCtx = ctx as GenerationCtx;
@@ -227,7 +236,12 @@ function createFragmentFn(

try {
generationCtx.callStack.push(outputType);
return core.resolve(ctx, '@fragment ');
return core.resolve(
ctx,
inputWithLocation ? [inputWithLocation] : [],
shell.returnType,
'@fragment ',
);
} finally {
generationCtx.callStack.pop();
}
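
Taken together, the changes mean a fragment function's `in` record gets its `@location` indices from `ctx.varyingLocations` at resolution time, so they line up with whatever the paired vertex stage emitted, even when the fragment stage only reads a subset of the varyings. In WGSL, a fragment input is only fed by the vertex output with the same `@location` index, so the two stages have to agree on the numbering. A rough usage sketch under the same import assumptions as above; the `vertexFn` option shape and `d.builtin.position` are taken from the TypeGPU docs, not from this diff:

```ts
import tgpu from 'typegpu';
import * as d from 'typegpu/data';

// The vertex stage emits two varyings next to the required builtin position.
const mainVertex = tgpu['~unstable'].vertexFn({
  out: { pos: d.builtin.position, cell: d.f32, uv: d.vec2f },
})(() => ({
  pos: d.vec4f(0, 0, 0, 1),
  cell: 1,
  uv: d.vec2f(0.5, 0.5),
}));

// The fragment stage declares only the varying it actually reads. With this PR,
// `cell` resolves with the same @location index the vertex output used, instead
// of both sides independently counting from 0.
const mainFragment = tgpu['~unstable'].fragmentFn({
  in: { cell: d.f32 },
  out: d.vec4f,
})((input) => d.vec4f(input.cell, 0, 0, 1));
```
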