
Kits reloaded #1633

Open · wants to merge 19 commits into base: main
2 changes: 1 addition & 1 deletion .travis/cli-tests.sh
@@ -118,7 +118,7 @@ $TRACT_RUN $MODELS/mdl-en-2019-Q3-librispeech.onnx \
 $CACHE_FILE hey_snips_v4_model17.pb
 $TRACT_RUN $MODELS/hey_snips_v4_model17.pb \
     -i S,20,f32 --pulse 8 dump --cost -q \
-    --assert-cost "FMA(F32)=2060448,Div(F32)=24576,Buffer(F32)=2920,Params(F32)=222250"
+    --assert-cost "FMA(F32)=2060448,Div(F32)=24576,Buffer(F32)=2920,Params(F32)=222251"
 
 $TRACT_RUN $MODELS/hey_snips_v4_model17.pb -i S,20,f32 \
     dump -q \
21 changes: 10 additions & 11 deletions core/src/model/fact.rs
@@ -411,18 +411,17 @@ impl<'a> From<&'a Arc<Tensor>> for TypedFact {
 
 impl fmt::Debug for TypedFact {
     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
-        match (self.konst.as_ref(), self.opaque_fact.as_ref()) {
-            (Some(ref k), None) => write!(fmt, "{k:?}"),
-            (Some(ref k), Some(opaque)) => write!(fmt, "{k:?} 🔍 {opaque:?}"),
-            (None, None) if self.rank() > 0 => {
-                write!(fmt, "{:?},{:?}", self.shape, self.datum_type)
-            }
-            (None, Some(ref opaque)) if self.rank() > 0 => {
-                write!(fmt, "{:?},{:?} 🔍 {opaque:?}", self.shape, self.datum_type)
+        write!(fmt, "{:?},{:?}", self.shape, self.datum_type)?;
+        if self.datum_type.is_opaque() {
+            if let Some(of) = &self.opaque_fact {
+                write!(fmt, " 🔍 {:?} ", of)?
+            } else {
+                write!(fmt, " 🔍 <no opaque fact> ")?
             }
-            (None, Some(ref opaque)) => write!(fmt, "{:?} 🔍 {opaque:?}", self.datum_type),
-            (None, None) => write!(fmt, "{:?}", self.datum_type),
-        }?;
+        }
+        if let Some(k) = &self.konst {
+            write!(fmt, "🟰 {:?}", k)?
+        }
         Ok(())
     }
 }
61 changes: 45 additions & 16 deletions core/src/model/patch.rs
@@ -6,7 +6,6 @@ use tract_data::itertools::{izip, Itertools};
 
 use crate::internal::*;
 use crate::model::*;
-use crate::ops::konst::Const;
 
 /// A change to apply to a model.
 ///
@@ -104,7 +103,12 @@
     pub fn tap_model(&mut self, model: &Graph<F, O>, outlet: OutletId) -> TractResult<OutletId> {
         let fact = model.outlet_fact(outlet)?;
         let id = self.add_source(
-            format!("tap.{}-{}/{}", model.node(outlet.node).name, outlet.node, outlet.slot),
+            format!(
+                "tap.{}-{}/{}",
+                model.node(outlet.node).name,
+                outlet.node,
+                outlet.slot
+            ),
             dyn_clone::clone(fact),
         )?;
         self.taps.insert(id, outlet);
@@ -119,7 +123,10 @@
         model: &Graph<F, O>,
         outlets: impl IntoIterator<Item = &'a OutletId>,
     ) -> TractResult<TVec<OutletId>> {
-        outlets.into_iter().map(|o| self.tap_model(model, *o)).collect::<TractResult<TVec<_>>>()
+        outlets
+            .into_iter()
+            .map(|o| self.tap_model(model, *o))
+            .collect::<TractResult<TVec<_>>>()
     }
 
     pub unsafe fn shunt_outside_unchecked(
@@ -141,7 +148,14 @@
         let original_fact = model.outlet_fact(outlet)?;
         let new_fact = self.model.outlet_fact(by)?;
         if !original_fact.compatible_with(new_fact) {
-            bail!("Trying to substitute a {:?} by {:?}.\n{:?}", original_fact, new_fact, self);
+            bail!(
+                "Trying to substitute a {:?} by {:?} as output #{} of {}.\n{:?}",
+                original_fact,
+                new_fact,
+                outlet.slot,
+                model.node(outlet.node),
+                self
+            );
         }
         self.shunts.insert(outlet, by);
         Ok(())
@@ -200,9 +214,11 @@
         {
             Ok(None)
         } else {
-            Self::rewire(patched_model, &node.inputs, &[node.id.into()], &|_p, xs| Ok(xs.into()))
-                .with_context(|| format!("Shunting {node}"))
-                .map(Some)
+            Self::rewire(patched_model, &node.inputs, &[node.id.into()], &|_p, xs| {
+                Ok(xs.into())
+            })
+            .with_context(|| format!("Shunting {node}"))
+            .map(Some)
         }
     }
 
@@ -295,11 +311,13 @@
                 // this is a tap
                 continue;
             }
-            if let Some(k) = node.op_as::<Const>() {
-                mapping.insert(node.id.into(), target.add_const(&node.name, k.0.clone())?);
-                continue;
-            }
-            let Node { id: patch_node_id, name, inputs, op, outputs } = node;
+            let Node {
+                id: patch_node_id,
+                name,
+                inputs,
+                op,
+                outputs,
+            } = node;
             let n_outputs = outputs.len();
             for dup in 0..target.nodes.len() {
                 if target.node(dup).op().same_as(op.as_ref())
@@ -319,7 +337,10 @@
             let added_node_id = target.add_node(name, op, facts)?;
             new_nodes.insert(added_node_id);
             for ix in 0..n_outputs {
-                mapping.insert(OutletId::new(patch_node_id, ix), OutletId::new(added_node_id, ix));
+                mapping.insert(
+                    OutletId::new(patch_node_id, ix),
+                    OutletId::new(added_node_id, ix),
+                );
             }
             all_inputs.insert(added_node_id, inputs);
             if <Graph<F, O>>::is_source(&target.node(added_node_id).op) {
@@ -335,7 +356,9 @@
         debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
         for (&outlet, &by) in shunt_outlet_by.iter().sorted() {
             let replace_by = mapping[&by];
-            let succs = target.nodes()[outlet.node].outputs[outlet.slot].successors.clone();
+            let succs = target.nodes()[outlet.node].outputs[outlet.slot]
+                .successors
+                .clone();
             for succ in succs {
                 target.add_edge(replace_by, succ)?;
             }
@@ -368,14 +391,20 @@
             maybe_garbage.remove(&maybe);
             if !target.outputs.iter().any(|output| output.node == maybe)
                 && !target.inputs.iter().any(|input| input.node == maybe)
-                && target.node(maybe).outputs.iter().all(|of| of.successors.is_empty())
+                && target
+                    .node(maybe)
+                    .outputs
+                    .iter()
+                    .all(|of| of.successors.is_empty())
             {
                 target.node_mut(maybe).op = target.create_dummy();
                 target.node_mut(maybe).name = format!("Dummy-node-{}", maybe);
                 target.node_mut(maybe).outputs.clear(); // necessary to drop facts and consts
                 let inputs = std::mem::take(&mut target.node_mut(maybe).inputs);
                 for &i in &inputs {
-                    target.node_mut(i.node).outputs[i.slot].successors.retain(|s| s.node != maybe);
+                    target.node_mut(i.node).outputs[i.slot]
+                        .successors
+                        .retain(|s| s.node != maybe);
                     maybe_garbage.insert(i.node);
                 }
                 target.check_edges()?;
48 changes: 34 additions & 14 deletions core/src/model/typed.rs
@@ -2,11 +2,11 @@ use crate::internal::*;
 use crate::model::*;
 use crate::ops;
 use crate::ops::konst::Const;
-use crate::ops::matmul::de_block_quant::BlockQuantValue;
 use crate::optim::OptimizerSession;
 use crate::plan::{FrozenSimpleState, SimplePlan, SimpleState};
 use crate::transform::ModelTransform;
 use tract_data::TooEarly;
+use tract_linalg::frame::block_quant::BlockQuantValue;
 use tract_num_traits::Zero;
 
 /// A model with completely determined types and shapes.
@@ -29,7 +29,9 @@ pub type RunnableModel<F, O, M> = SimplePlan<F, O, M>;
 
 impl SpecialOps<TypedFact, Box<dyn TypedOp>> for TypedModel {
     fn is_source(op: &Box<dyn TypedOp>) -> bool {
-        op.as_op().downcast_ref::<ops::source::TypedSource>().is_some()
+        op.as_op()
+            .downcast_ref::<ops::source::TypedSource>()
+            .is_some()
     }
 
     fn create_dummy(&self) -> Box<dyn TypedOp> {
@@ -49,9 +51,10 @@
         let op = op.into();
         let name = name.into();
         if let Some(konst) = op.downcast_ref::<Const>() {
-            // only if no opaque fact is present.
-            if konst.1.is_none() {
-                return Ok(tvec![self.add_const(name, konst.0.clone())?]);
+            for node in &self.nodes {
+                if node.op_as::<Const>().is_some_and(|other| other == konst) {
+                    return Ok(tvec!(node.id.into()));
+                }
             }
         }
         if self.nodes.iter().any(|n| n.name == name) {
@@ -80,8 +83,11 @@
                 .into_iter()
                 .enumerate()
                 .map(|(ix, o)| {
-                    let name =
-                        if ix == 0 { name.clone() } else { format!("{name}.{ix}") };
+                    let name = if ix == 0 {
+                        name.clone()
+                    } else {
+                        format!("{name}.{ix}")
+                    };
                     self.add_const(name, o)
                 })
                 .collect::<TractResult<TVec<OutletId>>>();
@@ -145,7 +151,8 @@
                 .map(|id| id.into());
             }
         }
-        self.add_node(name, crate::ops::konst::Const::new(v), tvec!(fact)).map(|id| id.into())
+        self.add_node(name, crate::ops::konst::Const::new(v), tvec!(fact))
+            .map(|id| id.into())
     }
 }
 
@@ -202,7 +209,11 @@
         for node in &self.nodes {
            for (ix, output) in node.outputs.iter().enumerate() {
                 output.fact.consistent().with_context(|| {
-                    format!("Inconsistent fact {:?}: {:?}", OutletId::new(node.id, ix), output.fact)
+                    format!(
+                        "Inconsistent fact {:?}: {:?}",
+                        OutletId::new(node.id, ix),
+                        output.fact
+                    )
                 })?
             }
         }
@@ -222,7 +233,9 @@
 
     /// Perform declutter passes on the network.
     pub fn declutter(&mut self) -> TractResult<()> {
-        crate::optim::Optimizer::declutter().session().optimize(self)
+        crate::optim::Optimizer::declutter()
+            .session()
+            .optimize(self)
     }
 
     /// Perform optimization passes on the model, using a given optimizer session.
@@ -260,9 +273,14 @@
                 && inputs.iter().all(|i| i.konst.is_some())
                 && outputs.iter().any(|o| o.konst.is_none())
             {
-                let inputs_ref =
-                    inputs.iter().map(|f| f.konst.clone().unwrap().into_tvalue()).collect();
-                match node.op.eval_with_session(&SessionState::default(), inputs_ref) {
+                let inputs_ref = inputs
+                    .iter()
+                    .map(|f| f.konst.clone().unwrap().into_tvalue())
+                    .collect();
+                match node
+                    .op
+                    .eval_with_session(&SessionState::default(), inputs_ref)
+                {
                     Ok(res) => {
                         drop(inputs);
                         drop(outputs);
@@ -294,7 +312,9 @@ impl Translate<TypedFact, Box<dyn TypedOp>, TypedFact, Box<dyn TypedOp>> for Sym
         mapping: &HashMap<OutletId, OutletId>,
     ) -> TractResult<TVec<OutletId>> {
         target.check_consistency()?;
-        let outlets = node.op.concretize_dims(source, node, target, mapping, self)?;
+        let outlets = node
+            .op
+            .concretize_dims(source, node, target, mapping, self)?;
         for &outlet in &outlets {
             let fact = &mut target.nodes[outlet.node].outputs[outlet.slot].fact;
             if fact.shape.volume().is_zero() {
27 changes: 17 additions & 10 deletions core/src/ops/array/gather.rs
@@ -1,7 +1,7 @@
 use crate::internal::*;
 use crate::ops::einsum::block_quant_aware_input_shape;
-use crate::ops::matmul::de_block_quant::BlockQuantValue;
 use ndarray::*;
+use tract_linalg::frame::block_quant::BlockQuantValue;
 
 #[derive(Debug, Clone, new, Hash)]
 pub struct Gather {
@@ -41,7 +41,11 @@ impl Gather {
             let mut icoords: TVec<usize> = ocoords[0..self.axis].into();
             let kcoords = &ocoords[self.axis..][..indices.ndim()];
             let k = indices[kcoords];
-            let k = if k < 0 { k + data_view.shape()[self.axis] as i64 } else { k } as usize;
+            let k = if k < 0 {
+                k + data_view.shape()[self.axis] as i64
+            } else {
+                k
+            } as usize;
             icoords.push(k);
             icoords.extend(ocoords[self.axis + indices.ndim()..].iter().copied());
             output_view[ocoords] = data_view.get(&*icoords).context("Invalid gather")?.clone();
@@ -52,8 +56,8 @@
 
     fn eval_bq_to_f16(&self, data: &BlockQuantValue, indices: &TValue) -> TractResult<Tensor> {
         ensure!(self.axis == 0);
-        ensure!(data.fact.shape.rank() == 2);
-        let data_shape = data.fact.shape.as_concrete().unwrap();
+        ensure!(data.fact.shape.len() == 2);
+        let data_shape = &data.fact.shape;
         let output_shape = &*self.compute_output_shape(data_shape, indices.shape())?;
         let mut output = unsafe { Tensor::uninitialized::<f16>(output_shape)? };
         let indices_slice = indices.as_slice::<i64>()?;
@@ -77,11 +81,13 @@ impl TypedOp for Gather {
         ensure!(inputs[1].datum_type == i64::datum_type());
         if inputs[0].datum_type.is_opaque() {
             let data_shape = block_quant_aware_input_shape(inputs[0])?;
-            Ok(tvec!(f16::fact(&*self.compute_output_shape(&data_shape, &inputs[1].shape)?)))
+            Ok(tvec!(f16::fact(
+                &*self.compute_output_shape(&data_shape, &inputs[1].shape)?
+            )))
         } else {
-            Ok(tvec!(inputs[0]
-                .datum_type
-                .fact(&*self.compute_output_shape(&inputs[0].shape, &inputs[1].shape)?)))
+            Ok(tvec!(inputs[0].datum_type.fact(
+                &*self.compute_output_shape(&inputs[0].shape, &inputs[1].shape)?
+            )))
         }
     }
 
@@ -149,8 +155,9 @@ mod tests {
         let gatherer = Gather::new(0);
         for idx in 2..3 {
             let index = Tensor::from(arr0(idx));
-            let outputs =
-                gatherer.eval(tvec![data.clone().into_tvalue(), index.into_tvalue()]).unwrap();
+            let outputs = gatherer
+                .eval(tvec![data.clone().into_tvalue(), index.into_tvalue()])
+                .unwrap();
             let output = &outputs[0];
             assert_eq!(output.shape().len(), 0);
             assert_eq!(*output.to_scalar::<i64>().unwrap(), idx + 1);