Skip to content

Commit

Permalink
fix: get input scales from graph for da creation (zkonduit#529)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethan-crypto authored Oct 6, 2023
1 parent 619ca5b commit 5800eb5
Show file tree
Hide file tree
Showing 11 changed files with 34 additions and 41 deletions.
17 changes: 2 additions & 15 deletions abis/DataAttestation.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
},
{
"internalType": "uint256[]",
"name": "_outputScales",
"name": "_scales",
"type": "uint256[]"
},
{
Expand All @@ -35,19 +35,6 @@
"stateMutability": "nonpayable",
"type": "constructor"
},
{
"inputs": [],
"name": "INPUT_SCALE",
"outputs": [
{
"internalType": "uint256",
"name": "",
"type": "uint256"
}
],
"stateMutability": "view",
"type": "function"
},
{
"inputs": [
{
Expand Down Expand Up @@ -106,7 +93,7 @@
"type": "uint256"
}
],
"name": "outputScales",
"name": "scales",
"outputs": [
{
"internalType": "uint256",
Expand Down
14 changes: 5 additions & 9 deletions contracts/AttestData.sol
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ contract DataAttestation {
}
AccountCall[] public accountCalls;

uint public constant INPUT_SCALE = 1 << 0;
uint[] public outputScales;
uint[] public scales;

address public admin;

Expand All @@ -55,13 +54,13 @@ contract DataAttestation {
address[] memory _contractAddresses,
bytes[][] memory _callData,
uint256[][] memory _decimals,
uint[] memory _outputScales,
uint[] memory _scales,
uint8 _instanceOffset,
address _admin
) {
admin = _admin;
for (uint i; i < _outputScales.length; i++) {
outputScales.push(1 << _outputScales[i]);
for (uint i; i < _scales.length; i++) {
scales.push(1 << _scales[i]);
}
populateAccountCalls(_contractAddresses, _callData, _decimals);
instanceOffset = _instanceOffset;
Expand Down Expand Up @@ -239,10 +238,7 @@ contract DataAttestation {
account,
accountCalls[i].callData[j]
);
uint256 scale = INPUT_SCALE;
if (counter >= INPUT_CALLS) {
scale = outputScales[counter - INPUT_CALLS];
}
uint256 scale = scales[counter];
int256 quantized_data = quantizeData(
returnData,
accountCalls[i].decimals[j],
Expand Down
2 changes: 1 addition & 1 deletion examples/onnx/mnist_gan/settings.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"run_args":{"tolerance":{"val":0.0,"scale":1.0},"input_scale":7,"param_scale":7,"scale_rebase_multiplier":10,"bits":16,"logrows":17,"variables":[["batch_size",1]],"input_visibility":"Private","output_visibility":"Public","param_visibility":"Public"},"num_constraints":8928903,"total_const_size":8753605,"model_instance_shapes":[[1,28,28]],"model_output_scales":[42],"module_sizes":{"poseidon":[0,[0]],"elgamal":[0,[0]]},"required_lookups":[{"Sigmoid":{"scale":4398046500000.0}},{"Exp":{"scale":2097152.0}},{"Exp":{"scale":34359740000.0}},{"GreaterThan":{"a":0.0}}],"check_mode":"UNSAFE","version":"0.0.0","num_blinding_factors":null}
{"run_args":{"tolerance":{"val":0.0,"scale":1.0},"input_scale":7,"param_scale":7,"scale_rebase_multiplier":10,"bits":16,"logrows":17,"variables":[["batch_size",1]],"input_visibility":"Private","output_visibility":"Public","param_visibility":"Public"},"num_constraints":8928903,"total_const_size":8753605,"model_instance_shapes":[[1,28,28]],"model_output_scales":[42],"model_input_scales":[7],"module_sizes":{"poseidon":[0,[0]],"elgamal":[0,[0]]},"required_lookups":[{"Sigmoid":{"scale":4398046500000.0}},{"Exp":{"scale":2097152.0}},{"Exp":{"scale":34359740000.0}},{"GreaterThan":{"a":0.0}}],"check_mode":"UNSAFE","version":"0.0.0","num_blinding_factors":null}
2 changes: 1 addition & 1 deletion examples/onnx/variable_cnn/settings.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"run_args":{"tolerance":{"val":0.0,"scales":[1,1]},"input_scale":11,"param_scale":11,"scale_rebase_multiplier":1,"bits":25,"logrows":26,"variables":[["batch_size",1]],"input_visibility":"Private","output_visibility":"Public","param_visibility":"Private"},"num_constraints":176820,"total_const_size":0,"model_instance_shapes":[[1,100]],"model_output_scales":[11],"module_sizes":{"poseidon":[0,[0]],"elgamal":[0,[0]]},"required_lookups":[{"Div":{"denom":2048.0}},"ReLU"],"check_mode":"UNSAFE","version":"0.0.0"}
{"run_args":{"tolerance":{"val":0.0,"scales":[1,1]},"input_scale":11,"param_scale":11,"scale_rebase_multiplier":1,"bits":25,"logrows":26,"variables":[["batch_size",1]],"input_visibility":"Private","output_visibility":"Public","param_visibility":"Private"},"num_constraints":176820,"total_const_size":0,"model_instance_shapes":[[1,100]],"model_output_scales":[11],"model_input_scales":[11],"module_sizes":{"poseidon":[0,[0]],"elgamal":[0,[0]]},"required_lookups":[{"Div":{"denom":2048.0}},"ReLU"],"check_mode":"UNSAFE","version":"0.0.0"}
21 changes: 12 additions & 9 deletions src/eth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,17 @@ pub async fn deploy_da_verifier_via_solidity(
let mut contract_instance_offset = 0;

if let DataSource::OnChain(source) = input.input_data {
let input_scales = settings.model_input_scales;
for call in source.calls {
calls_to_accounts.push(call);
}

// give each input a scale
for scale in input_scales {
scales.extend(vec![
scale;
instance_shapes[instance_idx].iter().product::<usize>()
]);
instance_idx += 1;
}
} else if let DataSource::File(source) = input.input_data {
Expand Down Expand Up @@ -644,7 +653,7 @@ pub fn get_contract_artifacts(

/// Sets the constants stored in the da verifier
pub fn fix_da_sol(
input_data: Option<(u32, Vec<CallsToAccount>)>,
input_data: Option<Vec<CallsToAccount>>,
output_data: Option<Vec<CallsToAccount>>,
) -> Result<String, Box<dyn Error>> {

Expand All @@ -653,14 +662,8 @@ pub fn fix_da_sol(
// fill in the quantization params and total calls
// as constants to the contract to save on gas
if let Some(input_data) = input_data {
let input_calls: usize = input_data.1.iter().map(|v| v.call_data.len()).sum();
let input_scale = input_data.0;
accounts_len = input_data.1.len();
contract = contract.replace(
"uint public constant INPUT_SCALE = 1 << 0;",
&format!("uint public constant INPUT_SCALE = 1 << {};", input_scale),
);

let input_calls: usize = input_data.iter().map(|v| v.call_data.len()).sum();
accounts_len = input_data.len();
contract = contract.replace(
"uint256 constant INPUT_CALLS = 0;",
&format!("uint256 constant INPUT_CALLS = {};", input_calls),
Expand Down
7 changes: 4 additions & 3 deletions src/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ pub async fn run(cli: Cli) -> Result<(), Box<dyn Error>> {
sol_code_path,
abi_path,
data,
} => create_evm_data_attestation_verifier(
} => create_evm_data_attestation(
vk_path,
srs_path,
settings_path,
Expand Down Expand Up @@ -615,6 +615,7 @@ pub(crate) async fn calibrate(
run_args: found_run_args,
required_lookups: settings.required_lookups,
model_output_scales: settings.model_output_scales,
model_input_scales: settings.model_input_scales,
num_constraints: settings.num_constraints,
total_const_size: settings.total_const_size,
..original_settings.clone()
Expand Down Expand Up @@ -843,7 +844,7 @@ pub(crate) fn create_evm_verifier(
}

#[cfg(not(target_arch = "wasm32"))]
pub(crate) fn create_evm_data_attestation_verifier(
pub(crate) fn create_evm_data_attestation(
vk_path: PathBuf,
srs_path: PathBuf,
settings_path: PathBuf,
Expand Down Expand Up @@ -893,7 +894,7 @@ pub(crate) fn create_evm_data_attestation_verifier(
for call in source.calls {
on_chain_input_data.push(call);
}
Some((settings.run_args.input_scale, on_chain_input_data))
Some(on_chain_input_data)
} else {
None
};
Expand Down
2 changes: 2 additions & 0 deletions src/graph/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,8 @@ pub struct GraphSettings {
pub model_instance_shapes: Vec<Vec<usize>>,
/// model output scales
pub model_output_scales: Vec<u32>,
/// model input scales
pub model_input_scales: Vec<u32>,
/// the of instance cells used by modules
pub module_sizes: ModuleSizes,
/// required_lookups
Expand Down
1 change: 1 addition & 0 deletions src/graph/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,7 @@ impl Model {
num_constraints,
required_lookups: lookup_ops,
model_output_scales: self.graph.get_output_scales(),
model_input_scales: self.graph.get_input_scales(),
total_const_size,
check_mode,
version: env!("CARGO_PKG_VERSION").to_string(),
Expand Down
4 changes: 2 additions & 2 deletions src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -916,7 +916,7 @@ fn create_evm_data_attestation(
abi_path: PathBuf,
input_data: PathBuf,
) -> Result<bool, PyErr> {
crate::execute::create_evm_data_attestation_verifier(
crate::execute::create_evm_data_attestation(
vk_path,
srs_path,
settings_path,
Expand All @@ -925,7 +925,7 @@ fn create_evm_data_attestation(
input_data,
)
.map_err(|e| {
let err_str = format!("Failed to run create_evm_data_attestation_verifier: {}", e);
let err_str = format!("Failed to run create_evm_data_attestation: {}", e);
PyRuntimeError::new_err(err_str)
})?;

Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1815,7 +1815,7 @@ mod native_tests {
.status()
.expect("failed to execute process");
assert!(status.success());

let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"gen-witness",
Expand Down
3 changes: 3 additions & 0 deletions tests/wasm/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
"model_output_scales": [
7
],
"model_input_scales": [
20
],
"model_instance_shapes": [
[
1,
Expand Down

0 comments on commit 5800eb5

Please sign in to comment.