diff --git a/co-noir/co-builder/src/acir_format.rs b/co-noir/co-builder/src/acir_format.rs index f99eca14..b95b6c83 100644 --- a/co-noir/co-builder/src/acir_format.rs +++ b/co-noir/co-builder/src/acir_format.rs @@ -1,7 +1,7 @@ use acir::{ acir_field::GenericFieldElement, circuit::{ - opcodes::{BlackBoxFuncCall, MemOp}, + opcodes::{BlackBoxFuncCall, FunctionInput, MemOp}, Circuit, }, native_types::{Expression, Witness, WitnessMap}, @@ -11,8 +11,8 @@ use ark_ff::{PrimeField, Zero}; use std::collections::{BTreeMap, HashMap, HashSet}; use crate::types::types::{ - AcirFormatOriginalOpcodeIndices, BlockConstraint, BlockType, MulQuad, PolyTriple, - RangeConstraint, RecursionConstraint, + AcirFormatOriginalOpcodeIndices, BlockConstraint, BlockType, LogicConstraint, MulQuad, + PolyTriple, RangeConstraint, RecursionConstraint, WitnessOrConstant, }; #[derive(Default)] @@ -28,7 +28,7 @@ pub struct AcirFormat { // using PolyTripleConstraint = bb::poly_triple_; pub public_inputs: Vec, pub(crate) range_constraints: Vec, - // std::vector logic_constraints; + pub(crate) logic_constraints: Vec>, // std::vector aes128_constraints; // std::vector sha256_constraints; // std::vector sha256_compression; @@ -579,6 +579,17 @@ impl AcirFormat { block.trace.push(acir_mem_op); } + fn parse_input(input: FunctionInput>) -> WitnessOrConstant { + match input.input() { + acir::circuit::opcodes::ConstantOrWitnessEnum::Witness(witness) => { + WitnessOrConstant::from_index(witness.0) + } + acir::circuit::opcodes::ConstantOrWitnessEnum::Constant(constant) => { + WitnessOrConstant::from_constant(constant.into_repr()) + } + } + } + fn handle_blackbox_func_call( arg: BlackBoxFuncCall>, af: &mut AcirFormat, @@ -592,16 +603,34 @@ impl AcirFormat { key: _, outputs: _, } => todo!("BlackBoxFuncCall::AES128Encrypt"), - BlackBoxFuncCall::AND { - lhs: _, - rhs: _, - output: _, - } => todo!("BlackBoxFuncCall::AND"), - BlackBoxFuncCall::XOR { - lhs: _, - rhs: _, - output: _, - } => todo!("BlackBoxFuncCall::XOR"), + BlackBoxFuncCall::AND { lhs, rhs, output } => { + let lhs_input = Self::parse_input(lhs); + let rhs_input = Self::parse_input(rhs); + af.logic_constraints.push(LogicConstraint::and_gate( + lhs_input, + rhs_input, + output.0, + lhs.num_bits(), + )); + af.constrained_witness.insert(output.0); + af.original_opcode_indices + .logic_constraints + .push(opcode_index); + } + BlackBoxFuncCall::XOR { lhs, rhs, output } => { + let lhs_input = Self::parse_input(lhs); + let rhs_input = Self::parse_input(rhs); + af.logic_constraints.push(LogicConstraint::xor_gate( + lhs_input, + rhs_input, + output.0, + lhs.num_bits(), + )); + af.constrained_witness.insert(output.0); + af.original_opcode_indices + .logic_constraints + .push(opcode_index); + } BlackBoxFuncCall::RANGE { input } => { let witness_input = input.to_witness().witness_index(); af.range_constraints.push(RangeConstraint { diff --git a/co-noir/co-builder/src/types/types.rs b/co-noir/co-builder/src/types/types.rs index b32ac703..4f2f8965 100644 --- a/co-noir/co-builder/src/types/types.rs +++ b/co-noir/co-builder/src/types/types.rs @@ -97,7 +97,7 @@ pub(crate) struct BlockConstraint { #[derive(Default)] pub(crate) struct AcirFormatOriginalOpcodeIndices { - // pub(crate) logic_constraints: Vec, + pub(crate) logic_constraints: Vec, pub(crate) range_constraints: Vec, // pub(crate) aes128_constraints: Vec, // pub(crate) sha256_constraints: Vec, @@ -400,6 +400,46 @@ pub(crate) struct RangeConstraint { pub(crate) num_bits: u32, } +pub(crate) struct LogicConstraint { + pub(crate) a: WitnessOrConstant, + pub(crate) b: WitnessOrConstant, + pub(crate) result: u32, + pub(crate) num_bits: u32, + pub(crate) is_xor_gate: bool, +} + +impl LogicConstraint { + pub(crate) fn and_gate( + a: WitnessOrConstant, + b: WitnessOrConstant, + result: u32, + num_bits: u32, + ) -> Self { + Self { + a, + b, + result, + num_bits, + is_xor_gate: false, + } + } + + pub(crate) fn xor_gate( + a: WitnessOrConstant, + b: WitnessOrConstant, + result: u32, + num_bits: u32, + ) -> Self { + Self { + a, + b, + result, + num_bits, + is_xor_gate: true, + } + } +} + pub(crate) struct RecursionConstraint { // An aggregation state is represented by two G1 affine elements. Each G1 point has // two field element coordinates (x, y). Thus, four field elements @@ -1338,3 +1378,28 @@ impl PermutationMapping { Self { sigmas, ids } } } + +#[derive(Clone, Debug, PartialEq, Eq)] +pub(crate) struct WitnessOrConstant { + index: u32, + value: F, + is_constant: bool, +} + +impl WitnessOrConstant { + pub(crate) fn from_index(index: u32) -> Self { + Self { + index, + value: F::zero(), + is_constant: false, + } + } + + pub(crate) fn from_constant(constant: F) -> Self { + Self { + index: 0, + value: constant, + is_constant: true, + } + } +}