Skip to content

Conversation

@siq1
Copy link
Contributor

@siq1 siq1 commented May 27, 2025

The goal of this PR is to allow (arbitrary) reshape and transpose in Zkcuda. It will solve #133, and will utilize #161.

For overall mechanism, please refer to #133.


For circuit developers, this overall experience only differs a little. Here's an example:

Previous code:

let kernel_add_2: Kernel<M31Config> = compile_add_2_macro().unwrap();
let kernel_add_16: Kernel<M31Config> = compile_add_16_macro().unwrap();

let mut ctx: Context<M31Config> = Context::default();
let a = ctx.copy_to_device(&a, false);
let mut b = None;
call_kernel!(ctx, kernel_add_2, a, mut b);
let b = b.reshape(&[1, 16]);
let mut c = None;
call_kernel!(ctx, kernel_add_16, b, mut c);
let c = c.reshape(&[]);
let result: M31 = ctx.copy_to_host(c);
assert_eq!(result, M31::from(32 * 33 / 2));

let computation_graph = ctx.to_computation_graph();
let (prover_setup, verifier_setup) =
    ExpanderGKRProvingSystem::<M31Config>::setup(&computation_graph);
let proof = ExpanderGKRProvingSystem::<M31Config>::prove(
    &prover_setup,
    &computation_graph,
    &ctx.device_memories,
);
assert!(ExpanderGKRProvingSystem::<M31Config>::verify(
    &verifier_setup,
    &computation_graph,
    &proof
));

Current code:

// Kernel -> KernelPrimitive
let kernel_add_2_tmp: KernelPrimitive<M31Config> = compile_add_2_macro().unwrap();
let kernel_add_16: KernelPrimitive<M31Config> = compile_add_16_macro().unwrap();

let mut ctx: Context<M31Config> = Context::default();
let a = ctx.copy_to_device(&a); // Removed argument is_broadcast 
let mut b = None;
call_kernel!(ctx, kernel_add_2, 16, a, mut b).unwrap(); // Additional argument num_parallel
let b = b.reshape(&[1, 16]);
let mut c = None;
call_kernel!(ctx, kernel_add_16, 1, b, mut c).unwrap();
let c = c.reshape(&[]);
let result: M31 = ctx.copy_to_host(c);
assert_eq!(result, M31::from(32 * 33 / 2));

let computation_graph = ctx.compile_computation_graph().unwrap(); // Function name and behavior changed
ctx.solve_witness().unwrap(); // Additional step
let (prover_setup, verifier_setup) =
    ExpanderGKRProvingSystem::<M31Config>::setup(&computation_graph);
let proof = ExpanderGKRProvingSystem::<M31Config>::prove(
    &prover_setup,
    &computation_graph,
    &ctx.export_device_memories(), // Need to export now
);
assert!(ExpanderGKRProvingSystem::<M31Config>::verify(
    &verifier_setup,
    &computation_graph,
    &proof
));

Execution times also differs:

Particular Step Previous function Current function
Compile Func -> IR compile_xx_kernel compile_xx_kernel
Compile IR -> layered circuit compile_xx_kernel compile_computation_graph
Eval circuit for output call_kernel call_kernel
Eval circuit for hint witness call_kernel solve_witness

@siq1 siq1 changed the title Zkcuda new API Zkcuda new API and compilation refactor May 30, 2025
@siq1 siq1 marked this pull request as ready for review June 11, 2025 20:54
@siq1 siq1 requested a review from zhiyong1997 June 11, 2025 20:54
@siq1 siq1 mentioned this pull request Jun 11, 2025
@siq1 siq1 merged commit 3633da7 into master Jun 19, 2025
22 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants