Commit: feat: private GatherElements op (#462)

alexander-camuto authored Sep 6, 2023
1 parent 3c114e3 commit 2229625
Showing 16 changed files with 750 additions and 157 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/rust.yml
@@ -458,6 +458,8 @@ jobs:
      # # now dump the contents of the file into a file called kaggle.json
      # echo $KAGGLE_API_KEY > /home/ubuntu/.kaggle/kaggle.json
      # chmod 600 /home/ubuntu/.kaggle/kaggle.json
      - name: XGBoost
        run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_::tests_16_expects
      - name: Gradient boosted trees
        run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_::tests_15_expects
      - name: Random Forest
337 changes: 337 additions & 0 deletions examples/notebooks/xgboost.ipynb

Large diffs are not rendered by default.

65 changes: 65 additions & 0 deletions examples/onnx/xgboost/gen.py
@@ -0,0 +1,65 @@
# make sure you have the dependencies required here already installed
import json
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier as Gbc
import torch
import ezkl
import os
from torch import nn
import xgboost as xgb
from hummingbird.ml import convert

NUM_CLASSES = 3

iris = load_iris()
X, y = iris.data, iris.target
X = X.astype(np.float32)
X_train, X_test, y_train, y_test = train_test_split(X, y)
clr = Gbc(n_estimators=12)
clr.fit(X_train, y_train)

# convert to torch


torch_gbt = convert(clr, 'torch')

print(torch_gbt)
# check that predictions from torch match sklearn's
diffs = []

for i in range(len(X_test)):
    torch_pred = torch_gbt.predict(torch.tensor(X_test[i].reshape(1, -1)))
    sk_pred = clr.predict(X_test[i].reshape(1, -1))
    diffs.append(torch_pred != sk_pred[0])

print("num diff: ", sum(diffs))


# Input to the model
shape = X_train.shape[1:]
x = torch.rand(1, *shape, requires_grad=False)
torch_out = torch_gbt.predict(x)
# Export the model
torch.onnx.export(torch_gbt.model,  # model being run
                  # model input (or a tuple for multiple inputs)
                  x,
                  # where to save the model (can be a file or file-like object)
                  "network.onnx",
                  export_params=True,  # store the trained parameter weights inside the model file
                  opset_version=11,  # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names=['input'],  # the model's input names
                  output_names=['output'],  # the model's output names
                  dynamic_axes={'input': {0: 'batch_size'},  # variable length axes
                                'output': {0: 'batch_size'}})

d = ((x).detach().numpy()).reshape([-1]).tolist()

data = dict(input_shapes=[shape],
            input_data=[d],
            output_data=[(o).reshape([-1]).tolist() for o in torch_out])

# Serialize data into file:
json.dump(data, open("input.json", 'w'))
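
A quick, optional sanity check of the exported graph (not part of the committed script): run network.onnx with onnxruntime, assumed installed, and compare against the hummingbird/torch prediction for the same input x.

# optional: verify the exported ONNX graph agrees with the torch container
import onnxruntime as ort  # assumed extra dependency, not used by the script above

sess = ort.InferenceSession("network.onnx", providers=["CPUExecutionProvider"])
onnx_out = sess.run(None, {"input": x.detach().numpy()})
print("onnx outputs:", [o.reshape([-1]).tolist() for o in onnx_out])
print("torch output:", torch_out)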
1 change: 1 addition & 0 deletions examples/onnx/xgboost/input.json
@@ -0,0 +1 @@
{"input_shapes": [[4]], "input_data": [[0.4291560649871826, 0.007834196090698242, 0.1443231701850891, 0.6625866293907166]], "output_data": [[0]]}
Binary file added examples/onnx/xgboost/network.onnx
Binary file not shown.
35 changes: 34 additions & 1 deletion src/circuit/ops/hybrid.rs
@@ -41,6 +41,10 @@ pub enum HybridOp {
        dim: usize,
        constant_idx: Option<Tensor<usize>>,
    },
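    /// Like `Gather`, but the index tensor has the same rank as the input and
    /// picks one element along `dim` per output coordinate (ONNX GatherElements).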
    GatherElements {
        dim: usize,
        constant_idx: Option<Tensor<usize>>,
    },
}

impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
@@ -117,6 +121,19 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
                    (res.clone(), inter_equals)
                }
            }
            HybridOp::GatherElements { dim, constant_idx } => {
                if let Some(idx) = constant_idx {
                    let res = tensor::ops::gather_elements(&x, &idx, *dim)?;
                    (res.clone(), vec![])
                } else {
                    let y = inputs[1].clone().map(|x| felt_to_i128(x));
                    let inter_equals: Vec<Tensor<i128>> =
                        vec![Tensor::from(0..x.dims()[*dim] as i128)];
                    let res = tensor::ops::gather_elements(&x, &y.map(|x| x as usize), *dim)?;
                    (res.clone(), inter_equals)
                }
            }
            HybridOp::MaxPool2d {
                padding,
                stride,
@@ -167,6 +184,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
            HybridOp::Less { .. } => "LESS",
            HybridOp::Equals => "EQUALS",
            HybridOp::Gather { .. } => "GATHER",
            HybridOp::GatherElements { .. } => "GATHERELEMENTS",
        };
        name.into()
    }
@@ -186,6 +204,13 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
                    layouts::gather(config, region, values[..].try_into()?, *dim)?
                }
            }
            HybridOp::GatherElements { dim, constant_idx } => {
                if let Some(idx) = constant_idx {
                    tensor::ops::gather_elements(&values[0].get_inner_tensor()?, idx, *dim)?.into()
                } else {
                    layouts::gather_elements(config, region, values[..].try_into()?, *dim)?
                }
            }
            HybridOp::MaxPool2d {
                padding,
                stride,
@@ -226,6 +251,13 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
        }))
    }

    fn requires_specific_input_scales(&self) -> Vec<(usize, u32)> {
        match self {
            HybridOp::Gather { .. } | HybridOp::GatherElements { .. } => vec![(1, 0)],
            _ => vec![],
        }
    }

    fn out_scale(&self, in_scales: Vec<u32>) -> u32 {
        match self {
            HybridOp::Greater { .. }
@@ -269,7 +301,8 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
            HybridOp::Greater { .. }
            | HybridOp::Less { .. }
            | HybridOp::Equals
            | HybridOp::Gather { .. } => {
            | HybridOp::Gather { .. }
            | HybridOp::GatherElements { .. } => {
                vec![LookupOp::GreaterThan {
                    a: circuit::utils::F32(0.),
                }]
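
For intuition, the op added here follows ONNX GatherElements semantics: the index tensor has the same rank as the input, and each output entry takes the input element whose coordinate along dim is replaced by the index value. A minimal reference sketch in plain Python/numpy (matching torch.gather; not ezkl code):

import numpy as np
import torch

def gather_elements(x: np.ndarray, index: np.ndarray, dim: int) -> np.ndarray:
    # out[c] = x[c[0], ..., index[c] substituted at position `dim`, ..., c[n-1]]
    out = np.empty(index.shape, dtype=x.dtype)
    for coord in np.ndindex(*index.shape):
        src = list(coord)
        src[dim] = index[coord]
        out[coord] = x[tuple(src)]
    return out

x = np.array([[1.0, 2.0], [3.0, 4.0]])
idx = np.array([[0, 0], [1, 0]])
assert np.array_equal(gather_elements(x, idx, 0), np.array([[1.0, 2.0], [3.0, 2.0]]))
# torch.gather implements the same operator
assert torch.equal(torch.gather(torch.tensor(x), 0, torch.tensor(idx)),
                   torch.tensor(gather_elements(x, idx, 0)))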
84 changes: 84 additions & 0 deletions src/circuit/ops/layouts.rs
@@ -543,6 +543,90 @@ pub fn gather<F: PrimeField + TensorType + PartialOrd>(
    Ok(output.into())
}

/// GatherElements accumulated layout
pub fn gather_elements<F: PrimeField + TensorType + PartialOrd>(
    config: &BaseConfig<F>,
    region: &mut RegionCtx<F>,
    values: &[ValTensor<F>; 2],
    dim: usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
    let (mut input, mut index) = (values[0].clone(), values[1].clone());

    assert_eq!(input.dims().len(), index.dims().len());

    if !input.all_prev_assigned() {
        input = region.assign(&config.inputs[0], &input)?;
    }
    if !index.all_prev_assigned() {
        index = region.assign(&config.inputs[1], &index)?;
    }

    region.increment(std::cmp::max(input.len(), index.len()));

    // Calculate the output tensor size
    let input_dim = input.dims()[dim];
    let output_size = index.dims().to_vec();

    // Allocate memory for the output tensor
    let cartesian_coord = output_size
        .iter()
        .map(|x| 0..*x)
        .multi_cartesian_product()
        .collect::<Vec<_>>();

    let output: Result<Vec<ValType<F>>, Box<dyn Error>> = cartesian_coord
        .iter()
        .map(|coord| {
            let index_val = index.get_inner_tensor()?.get(coord);

            let mut slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
            slice[dim] = 0..input_dim;

            let mut sliced_input = input.get_slice(&slice)?;
            sliced_input.flatten();

            let index_valtensor: ValTensor<F> =
                Tensor::from([index_val.clone()].into_iter()).into();

            let res =
                select(config, region, &[sliced_input, index_valtensor])?.get_inner_tensor()?;

            Ok(res[0].clone())
        })
        .collect();

    let output = output?;

    let mut output: ValTensor<F> = Tensor::new(Some(&output), &[output.len()])?.into();
    // Reshape the output tensor
    output.reshape(&output_size)?;

    if matches!(&config.check_mode, CheckMode::SAFE) {
        // during key generation this will be unknown vals so we use this as a flag to check
        // TODO: this isn't very safe and would be better to get the phase directly
        let mut is_assigned = !output.any_unknowns();
        for val in values.iter() {
            is_assigned = is_assigned && !val.any_unknowns();
        }
        if is_assigned {
            let mut x = values[0].get_int_evals()?;
            x.reshape(&input.dims());
            let mut ind = values[1].get_int_evals()?;
            ind.reshape(&index.dims());

            let ref_gather: Tensor<i128> =
                tensor::ops::gather_elements(&x, &ind.map(|x| x as usize), dim)?;

            let mut output_evals = output.get_int_evals()?;
            output_evals.reshape(output.dims());

            assert_eq!(output_evals, ref_gather)
        }
    };

    Ok(output.into())
}

/// Sum accumulated layout
pub fn sum<F: PrimeField + TensorType + PartialOrd>(
    config: &BaseConfig<F>,
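
The layout above reduces an n-dimensional GatherElements to one constrained 1-D select per output coordinate: the input is sliced along dim into a fiber of length input.dims()[dim], flattened, and the select argument picks the entry the index points at. A plain-Python sketch of that decomposition (select here is a stand-in for the constrained select in layouts.rs):

import numpy as np
from itertools import product

def select(fiber, i):
    # stand-in for the circuit's constrained 1-D select argument
    return fiber[i]

def gather_elements_via_select(x: np.ndarray, index: np.ndarray, dim: int) -> np.ndarray:
    out = np.empty(index.shape, dtype=x.dtype)
    for coord in product(*(range(d) for d in index.shape)):
        # slice the input along `dim`, all other coordinates held fixed
        slicer = [slice(c, c + 1) for c in coord]
        slicer[dim] = slice(None)
        fiber = x[tuple(slicer)].reshape(-1)  # mirrors sliced_input.flatten()
        out[coord] = select(fiber, index[coord])
    return out

x = np.arange(12).reshape(3, 4)
idx = np.array([[2, 0, 1, 1], [0, 2, 2, 0], [1, 1, 0, 2]])
ref = np.take_along_axis(x, idx, axis=0)  # numpy's GatherElements
assert np.array_equal(gather_elements_via_select(x, idx, 0), ref)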
5 changes: 5 additions & 0 deletions src/circuit/ops/mod.rs
@@ -52,6 +52,11 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd>: std::fmt::Debug + Send +
        vec![]
    }

    /// Do any of the inputs to this op require specific input scales?
    fn requires_specific_input_scales(&self) -> Vec<(usize, u32)> {
        vec![]
    }

    /// Returns the lookups required by the operation.
    fn required_lookups(&self) -> Vec<LookupOp> {
        vec![]
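
The new trait hook lets an op pin chosen inputs to fixed scales; Gather and GatherElements use it to pin input 1 (the indices) to scale 0. The reason is fixed-point encoding: assuming ezkl's convention that a value v at scale s is quantized as round(v * 2^s), any nonzero scale would turn index i into i * 2^s and select the wrong element. A small sketch of that convention:

def quantize(v, scale):
    # fixed-point encoding sketch: scale = number of fractional bits
    return round(v * 2 ** scale)

print(quantize(3.5, 7))  # 448 -- data keeps its fractional precision
print(quantize(2, 0))    # 2   -- an index must stay a raw integer
print(quantize(2, 7))    # 256 -- a scaled "index" would address element 256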
9 changes: 5 additions & 4 deletions src/graph/model.rs
@@ -307,7 +307,8 @@ impl NodeType {
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
/// A set of EZKL nodes that represent a computational graph.
pub struct ParsedNodes {
    nodes: BTreeMap<usize, NodeType>,
    /// The nodes in the graph.
    pub nodes: BTreeMap<usize, NodeType>,
    inputs: Vec<usize>,
    outputs: Vec<Outlet>,
}
@@ -925,14 +926,14 @@ impl Model {
                }
            }
        }
        Self::clean_useless_consts(&mut nodes);
        Self::empty_raw_const_value(&mut nodes);

        Ok(nodes)
    }

    #[cfg(not(target_arch = "wasm32"))]
    /// Removes all nodes that are consts with 0 uses
    fn clean_useless_consts(nodes: &mut BTreeMap<usize, NodeType>) {
    fn empty_raw_const_value(nodes: &mut BTreeMap<usize, NodeType>) {
        // remove all nodes that are consts with 0 uses now
        nodes.retain(|_, n| match n {
            NodeType::Node(n) => match &mut n.opkind {
@@ -943,7 +944,7 @@
                _ => true,
            },
            NodeType::SubGraph { model, .. } => {
                Self::clean_useless_consts(&mut model.graph.nodes);
                Self::empty_raw_const_value(&mut model.graph.nodes);
                true
            }
        });
49 changes: 46 additions & 3 deletions src/graph/node.rs
@@ -273,7 +273,7 @@ impl SupportedOp {
    }

    #[cfg(not(target_arch = "wasm32"))]
    fn rescale(&self, in_scales: Vec<u32>) -> Box<dyn Op<Fp>> {
    fn homogenous_rescale(&self, in_scales: Vec<u32>) -> Box<dyn Op<Fp>> {
        let inputs_to_scale = self.requires_homogenous_input_scales();
        // creates a rescaled op if the inputs are not homogenous
        let op = self.clone_dyn();
@@ -391,6 +391,19 @@ impl Op<Fp> for SupportedOp {
        }
    }

    fn requires_specific_input_scales(&self) -> Vec<(usize, u32)> {
        match self {
            SupportedOp::Linear(op) => Op::<Fp>::requires_specific_input_scales(op),
            SupportedOp::Nonlinear(op) => Op::<Fp>::requires_specific_input_scales(op),
            SupportedOp::Hybrid(op) => Op::<Fp>::requires_specific_input_scales(op),
            SupportedOp::Input(op) => Op::<Fp>::requires_specific_input_scales(op),
            SupportedOp::Constant(op) => Op::<Fp>::requires_specific_input_scales(op),
            SupportedOp::Unknown(op) => Op::<Fp>::requires_specific_input_scales(op),
            SupportedOp::Rescaled(op) => Op::<Fp>::requires_specific_input_scales(op),
            SupportedOp::RebaseScale(op) => Op::<Fp>::requires_specific_input_scales(op),
        }
    }

    fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
        match self {
            SupportedOp::Linear(op) => Box::new(op.clone()),
@@ -580,7 +593,10 @@ impl Node {

        let homogenous_inputs = opkind.requires_homogenous_input_scales();
        // automatically increases a constant's scale if it is only used once
        for input in homogenous_inputs
            .into_iter()
            .filter(|i| !deleted_indices.contains(i))
        {
            let input_node = other_nodes.get_mut(&inputs[input].idx()).unwrap();
            let input_opkind = &mut input_node.opkind();
            if let Some(constant) = input_opkind.get_mutable_constant() {
if let Some(constant) = input_opkind.get_mutable_constant() {
Expand All @@ -592,7 +608,34 @@ impl Node {
}
}

opkind = opkind.rescale(in_scales.clone()).into();
let inputs_at_specific_scales = opkind.requires_specific_input_scales();
// rescale the inputs if necessary to get consistent fixed points
for (input, scale) in inputs_at_specific_scales
.into_iter()
.filter(|(i, _)| !deleted_indices.contains(i))
{
let input_node = other_nodes.get_mut(&inputs[input].idx()).unwrap();
let input_opkind = &mut input_node.opkind();
if let Some(constant) = input_opkind.get_mutable_constant() {
rescale_const_with_single_use(constant, in_scales.clone(), param_visibility)?;
input_node.replace_opkind(constant.clone_dyn().into());
let out_scale = input_opkind.out_scale(vec![]);
input_node.bump_scale(out_scale);
in_scales[input] = out_scale;
} else {
let scale_diff = in_scales[input] as i128 - scale as i128;
let rebased = if scale_diff > 0 {
RebaseScale::rebase(input_opkind.clone(), scale, in_scales[input], 1)
} else {
RebaseScale::rebase_up(input_opkind.clone(), scale, in_scales[input])
};
input_node.replace_opkind(rebased.into());
input_node.bump_scale(scale);
in_scales[input] = scale;
}
}

opkind = opkind.homogenous_rescale(in_scales.clone()).into();
let mut out_scale = opkind.out_scale(in_scales.clone());
opkind = RebaseScale::rebase(opkind, scales.input, out_scale, scales.rebase_multiplier);
out_scale = opkind.out_scale(in_scales);
Expand Down