Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: pod2 circuit in pex #31

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from
Prev Previous commit
Next Next commit
resolve conflict
  • Loading branch information
ax0 authored and ludns committed Nov 9, 2024
commit 1c00a9b0f85084f3bdd4f8411f01fc25b4dbd7ef
153 changes: 119 additions & 34 deletions pod2/src/pod/circuit/operation.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
use anyhow::Result;
use plonky2::{
field::goldilocks_field::GoldilocksField,
hash::poseidon::PoseidonHash,
iop::{
target::Target,
witness::{PartialWitness, WitnessWrite},
},
plonk::circuit_builder::CircuitBuilder,
};
use std::{array, collections::HashMap};
use std::iter::zip;
use std::{array, collections::HashMap};

use crate::{
pod::{
circuit::util::statement_matrix_ref, gadget::GadgetID, operation::{OpList, Operation as Op}, statement::{StatementOrRef, StatementRef}, GPGInput, OpCmd, Statement
circuit::util::statement_matrix_ref,
gadget::GadgetID,
operation::{OpList, Operation as Op},
statement::{StatementOrRef, StatementRef},
GPGInput, OpCmd, Statement,
},
D, F, NUM_BITS, POD,
};
Expand All @@ -21,7 +26,7 @@ use super::{
entry::EntryTarget,
origin::OriginTarget,
statement::{StatementRefTarget, StatementTarget},
util::assert_less_if,
util::{and, assert_less_if, member},
};

#[derive(Clone, Copy, Debug)]
Expand All @@ -31,7 +36,7 @@ pub struct OperationTarget<const VL: usize> {
pub operand2: StatementRefTarget,
pub operand3: StatementRefTarget,
pub entry: EntryTarget,
pub contains_proof: [Target; VL]
pub contains_proof: [Target; VL],
}

impl<const VL: usize> OperationTarget<VL> {
Expand All @@ -42,7 +47,7 @@ impl<const VL: usize> OperationTarget<VL> {
operand2: StatementRefTarget::new_virtual(builder),
operand3: StatementRefTarget::new_virtual(builder),
entry: EntryTarget::new_virtual(builder),
contains_proof: builder.add_virtual_target_arr()
contains_proof: builder.add_virtual_target_arr(),
}
}
// TODO: Perestroika!
Expand All @@ -51,7 +56,7 @@ impl<const VL: usize> OperationTarget<VL> {
pw: &mut PartialWitness<GoldilocksField>,
operation: &Op<StatementRef>,
ref_index_map: &HashMap<StatementRef, (usize, usize)>,
statement_table: &<StatementRef as StatementOrRef>::StatementTable
statement_table: &<StatementRef as StatementOrRef>::StatementTable,
) -> Result<()> {
let operation_as_fields = operation.to_fields::<VL>(ref_index_map, &statement_table)?;
pw.set_target(self.op, operation_as_fields[0])?;
Expand All @@ -71,10 +76,7 @@ impl<const VL: usize> OperationTarget<VL> {
&[self.entry.key, self.entry.value],
&[operation_as_fields[7], operation_as_fields[8]],
)?;
pw.set_target_arr(
&self.contains_proof,
&operation_as_fields[9..]
)?;
pw.set_target_arr(&self.contains_proof, &operation_as_fields[9..])?;
Ok(())
}

Expand Down Expand Up @@ -154,12 +156,11 @@ impl<const VL: usize> OperationTarget<VL> {
StatementTarget::not_equal(builder, statement1_target, statement2_target), // LtToNonequality. TODO.
];

// Indicators of whether the conditions on the operands were satisfied.
let statements_are_value_ofs = {
let s1_check = statement1_target.has_code(builder, Statement::VALUE_OF);
let s2_check = statement2_target.has_code(builder, Statement::VALUE_OF);
builder.and(s1_check, s2_check)
};
// Type indicators
let statement_is_valueof = [statement1_target, statement2_target, statement3_target]
.iter()
.map(|s_target| s_target.has_code(builder, Statement::VALUE_OF))
.collect::<Vec<_>>();

let statements_1_and_2_equal =
builder.is_equal(statement1_target.value, statement2_target.value);
Expand All @@ -180,7 +181,7 @@ impl<const VL: usize> OperationTarget<VL> {
let s2_check = statement2_target.has_code(builder, Statement::EQUAL);
builder.and(s1_check, s2_check)
};
// TODO
// TODO: Proper origin check
let statements_allow_transitivity = {
let origins_match = builder.is_equal(
statement1_target.origin2.origin_id,
Expand All @@ -190,20 +191,100 @@ impl<const VL: usize> OperationTarget<VL> {
builder.and(origins_match, keys_match)
};

// Do a membership check for `ContainsFromEntries`
// TODO: Type check args.
let scalar_is_member = member(builder, statement2_target.value, &self.contains_proof);
let proof_root = builder
.hash_n_to_hash_no_pad::<PoseidonHash>(self.contains_proof.to_vec())
.elements[0];
let root_is_valid = builder.is_equal(proof_root, statement1_target.value);

let op_is_valid = [
builder._true(), // None - no checks needed.
builder._true(), // NewEntry - no checks needed.
builder._true(), // Copy - no checks needed.
statements_1_and_2_equal, // EqualityFromEntries - equality check
builder._true(), // None - no checks needed.
builder._true(), // NewEntry - no checks needed.
builder._true(), // Copy - no checks needed.
and(
builder,
&[
statement_is_valueof[0],
statement_is_valueof[1],
statements_1_and_2_equal,
],
), // EqualityFromEntries - equality check
builder.not(statements_1_and_2_equal), // NonequalityFromEntries - non-equality check
statements_are_value_ofs, // GtFromEntries - Type-check input statements
and(builder, &[statement_is_valueof[0], statement_is_valueof[1]]), // GtFromEntries - Type-check input statements. TODO: Replace assertion above.
builder.and(statements_are_equalities, statements_allow_transitivity), // TransitiveEqualityFromStatements
statement1_target.has_code(builder, Statement::GT), // GtToNonequality
builder._true(), // TODO: ContainsFromEntries
builder._true(), // TODO: RenameContainedBy
builder._true(), // TODO: SumOf
builder._true(), // TODO: ProductOf
builder._true(), // TODO: MaxOf
builder.and(scalar_is_member, root_is_valid), // TODO: ContainsFromEntries
{
let conditions = &[
// Types
statement1_target.has_code(builder, Statement::CONTAINS),
statement2_target.has_code(builder, Statement::EQUAL),
// Anchored key equality. TODO.
builder.is_equal(statement1_target.key1, statement2_target.key1),
builder.is_equal(
statement1_target.origin1.origin_id,
statement2_target.origin1.origin_id,
),
];
and(builder, conditions)
}, // RenameContainedBy. TODO.
{
let conditions = &[
// Types
statement_is_valueof[0],
statement_is_valueof[1],
statement_is_valueof[2],
// s1 = s2 + s3
{
let rhs = builder.add(statement2_target.value, statement3_target.value);
builder.is_equal(statement1_target.value, rhs)
},
];
and(builder, conditions)
}, // SumOf
{
let conditions = &[
// Types
statement_is_valueof[0],
statement_is_valueof[1],
statement_is_valueof[2],
// s1 = s2 * s3
{
let rhs = builder.mul(statement2_target.value, statement3_target.value);
builder.is_equal(statement1_target.value, rhs)
},
];
and(builder, conditions)
}, // ProductOf
{
let conditions = &[
// Types
statement_is_valueof[0],
statement_is_valueof[1],
statement_is_valueof[2],
// s1 = max(s2, s3) <=> s1 >= s2, s3 and (s1 = s2 or s1 = s3)
{
// TODO: Replace assertions.
let maxof_opcode_target = builder.constant(Op::<Statement>::MAX_OF);
let one_target = builder.one();
let op_is_maxof = builder.is_equal(self.op, maxof_opcode_target);
let s1 = statement1_target.value;
let s1_plus_one = builder.add(statement1_target.value, one_target);
let s2 = statement2_target.value;
let s3 = statement3_target.value;
assert_less_if::<NUM_BITS>(builder, op_is_maxof, s2, s1_plus_one);
assert_less_if::<NUM_BITS>(builder, op_is_maxof, s3, s1_plus_one);

let s1_eq_s2 = builder.is_equal(s1, s2);
let s1_eq_s3 = builder.is_equal(s1, s3);

builder.or(s1_eq_s2, s1_eq_s3)
},
];
and(builder, conditions)
}, // MaxOf
builder._true(), // TODO: Lt
builder._true(), // TODO: LtToNonequality
]
Expand Down Expand Up @@ -241,14 +322,16 @@ pub struct OpListTarget<const NS: usize, const VL: usize>(pub [OperationTarget<V

impl<const NS: usize, const VL: usize> OpListTarget<NS, VL> {
pub fn new_virtual(builder: &mut CircuitBuilder<F, D>) -> Self {
OpListTarget(array::from_fn(|_| OperationTarget::<VL>::new_virtual(builder)))
OpListTarget(array::from_fn(|_| {
OperationTarget::<VL>::new_virtual(builder)
}))
}

pub fn set_witness(
&self,
pw: &mut PartialWitness<GoldilocksField>,
op_list: &OpList,
gpg_input: &GPGInput
gpg_input: &GPGInput,
) -> Result<()> {
// TODO: Abstract this away.
// Determine output POD statements for the purposes of later reference
Expand All @@ -263,13 +346,15 @@ impl<const NS: usize, const VL: usize> OpListTarget<NS, VL> {
// Set operation targets
let ref_index_map = StatementRef::index_map(&input_and_output_pod_list);
let statement_table: <StatementRef as StatementOrRef>::StatementTable =
input_and_output_pod_list.iter().map(
|(pod_name, pod)|
(pod_name.clone(), pod.payload.statements_map.clone())
).collect();
input_and_output_pod_list
.iter()
.map(|(pod_name, pod)| (pod_name.clone(), pod.payload.statements_map.clone()))
.collect();

zip(&self.0, op_list.sort(&input_and_output_pod_list).0).try_for_each(
|(op_target, OpCmd(op, _))| op_target.set_witness(pw, &op, &ref_index_map, &statement_table),
|(op_target, OpCmd(op, _))| {
op_target.set_witness(pw, &op, &ref_index_map, &statement_table)
},
)
}
}
9 changes: 7 additions & 2 deletions pod2/src/pod/circuit/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ pub fn statement_matrix_ref(
))
}

/// Less than assertion for targets known to fit within `num_bits` bits. This assumption is
/// also checked here.
/// Less than assertion for targets known to fit within `num_bits`
/// bits. This assumption is also checked here.
pub fn assert_less<const NUM_BITS: usize>(
builder: &mut CircuitBuilder<F, D>,
x: Target,
Expand Down Expand Up @@ -148,3 +148,8 @@ pub fn member(builder: &mut CircuitBuilder<F, D>, x: Target, v: &[Target]) -> Bo
builder.or(acc, eq_x_y)
})
}

pub fn and(builder: &mut CircuitBuilder<F, D>, v: &[BoolTarget]) -> BoolTarget {
v.iter()
.fold(builder._true(), |acc, ind| builder.and(acc, *ind))
}
41 changes: 26 additions & 15 deletions pod2/src/pod/gadget/opexecutor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@ use std::iter::zip;

use crate::{
pod::{
circuit::operation::OpListTarget, gadget::GadgetID, operation::{OpList, OperationCmd}, payload::StatementList, statement::{StatementOrRef, StatementRef}, GPGInput, POD
circuit::operation::OpListTarget,
gadget::GadgetID,
operation::{OpList, OperationCmd},
payload::StatementList,
statement::{StatementOrRef, StatementRef},
GPGInput, POD,
},
recursion::OpsExecutorTrait,
D, F,
Expand Down Expand Up @@ -53,7 +58,7 @@ impl<const NP: usize, const NS: usize, const VL: usize> OpsExecutorTrait
type Targets = (
[[StatementTarget; NS]; NP],
[Vec<Target>; NP],
OpListTarget<NS,VL>,
OpListTarget<NS, VL>,
[StatementTarget; NS], // registered as public input
);

Expand All @@ -64,8 +69,7 @@ impl<const NP: usize, const NS: usize, const VL: usize> OpsExecutorTrait
let origin_id_map_target: [Vec<Target>; NP] =
array::from_fn(|_| builder.add_virtual_targets(NS + 2));

let op_list_target =
OpListTarget::new_virtual(builder);
let op_list_target = OpListTarget::new_virtual(builder);

// TODO: Check that origin ID map has appropriate properties.

Expand Down Expand Up @@ -149,7 +153,7 @@ impl<const NP: usize, const NS: usize, const VL: usize> OpsExecutorTrait
// Set op list target.
let op_list_target = &targets.2;
op_list_target.set_witness(pw, &input.1, &input.0)?;

// Check output statement list target
zip(&targets.3, output).try_for_each(|(s_target, (_, s))| s_target.set_witness(pw, s))?;

Expand Down Expand Up @@ -188,6 +192,8 @@ mod tests {
// Input Schnorr PODs. For now, they must all have the same number
// of statements.
const NS: usize = 3;
const VL: usize = 3;

let schnorr_pod1_name = "Test POD 1".to_string();
let schnorr_pod1 = POD::execute_schnorr_gadget::<NS>(
&[
Expand Down Expand Up @@ -217,7 +223,10 @@ mod tests {
let schnorr_pod4_name = "Test POD 4".to_string();
let schnorr_pod4 = POD::execute_schnorr_gadget::<NS>(
&[
Entry::new_from_scalar("who", GoldilocksField(7)),
Entry::new_from_vec(
"who",
vec![GoldilocksField(5), GoldilocksField(6), GoldilocksField(7)],
),
Entry::new_from_scalar("what", GoldilocksField(5)),
],
&SchnorrSecretKey { sk: 20 },
Expand All @@ -230,6 +239,8 @@ mod tests {
(schnorr_pod4_name.clone(), schnorr_pod4),
];

const NP: usize = 4;

// Ops
let op_lists = [
OpList(vec![
Expand Down Expand Up @@ -274,13 +285,13 @@ mod tests {
]),
OpList(vec![
OpCmd::new(Op::None, "cons"),
// OpCmd::new(
// Op::ContainsFromEntries(
// StatementRef::new(&schnorr_pod4_name, "VALUEOF:who"),
// StatementRef::new(&schnorr_pod4_name, "VALUEOF:what"),
// ),
// "car",
// ),
OpCmd::new(
Op::ContainsFromEntries(
StatementRef::new(&schnorr_pod4_name, "VALUEOF:who"),
StatementRef::new(&schnorr_pod4_name, "VALUEOF:what"),
),
"car",
),
OpCmd::new(Op::None, "cdr"),
]),
]
Expand All @@ -297,8 +308,8 @@ mod tests {
let oracle_pod = POD::execute_oracle_gadget(&gpg_input, &op_list.0)?;

// circuit test
let targets = OpExecutorGadget::<4, 3, 0>::add_targets(&mut builder)?;
OpExecutorGadget::<4, 3, 0>::set_targets(
let targets = OpExecutorGadget::<NP, NS, VL>::add_targets(&mut builder)?;
OpExecutorGadget::<NP, NS, VL>::set_targets(
&mut pw,
&targets,
&(gpg_input.clone(), op_list),
Expand Down
10 changes: 9 additions & 1 deletion pod2/src/pod/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,14 @@ impl OpList {
NS
));
}
Ok(Self([self.0, (op_list_len..NS).map( |i| OperationCmd( Operation::None, format!("_DUMMY_STATEMENT{}", i))).collect()].concat()))
Ok(Self(
[
self.0,
(op_list_len..NS)
.map(|i| OperationCmd(Operation::None, format!("_DUMMY_STATEMENT{}", i)))
.collect(),
]
.concat(),
))
}
}