Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Kits reloaded #1633

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
wip codegen consts
  • Loading branch information
kali committed Jan 30, 2025
commit 0c11922c06d3d23fd5dada7039205676725c8702
65 changes: 65 additions & 0 deletions core/src/ops/konst.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
use tract_itertools::Itertools;

use crate::internal::*;
use crate::ops::array::Gather;
use crate::ops::einsum::EinSum;
use crate::ops::matmul::de_block_quant::BlockQuantValue;

/// A constant value in the typed graph: the tensor payload, plus an optional
/// fact describing it when the tensor is opaque (e.g. a block-quantized
/// weight payload). NOTE(review): the precise `OpaqueFact` contract is
/// defined elsewhere in the crate — confirm before relying on it here.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct Const(pub Arc<Tensor>, pub Option<Box<dyn OpaqueFact>>);
Expand Down Expand Up @@ -86,4 +91,64 @@ impl TypedOp for Const {
};
target.wire_node(&node.name, op, &[])
}

/// Codegen pass for `Const`: detect constants that look like weight matrices
/// (rank-2 numeric tensors, or opaque scalars holding a `BlockQuantValue`)
/// whose every consumer is either a `Gather` on axis 0 or an `EinSum` shaped
/// like a matrix multiplication over this input.
///
/// This is groundwork for pre-packing shared weights: for now it only
/// performs the detection and always returns `Ok(None)` (no patch emitted).
///
/// # Errors
/// Propagates errors from axis resolution on the einsum expression.
fn codegen(
    &self,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    // Weight-like: a rank-2 numeric tensor, or an opaque scalar wrapping
    // a block-quantized payload.
    let looks_like_weights = (self.0.datum_type().is_number() && self.0.rank() == 2)
        || self.0.to_scalar::<Opaque>().is_ok_and(|opaque| opaque.is::<BlockQuantValue>());
    if !looks_like_weights {
        return Ok(None);
    }
    let mut have_abstract_einsum = false;
    for succ in &node.outputs[0].successors {
        let snode = model.node(succ.node);
        if let Some(gather) = snode.op_as::<Gather>() {
            // Only accept gathers that index this constant (slot 0) on axis 0.
            if succ.slot != 0 || gather.axis != 0 {
                return Ok(None);
            }
        } else if let Some(einsum) = snode.op_as::<EinSum>() {
            // The constant must be the first of exactly two einsum inputs.
            if succ.slot != 0 || snode.inputs.len() != 2 {
                return Ok(None);
            }
            // Axis 0 of the constant must act as the M axis: present once
            // in this input and the output, absent from the other input.
            let m_axis = einsum.axes.axis((InOut::In(0), 0))?;
            if m_axis.inputs[0].len() != 1
                || !m_axis.inputs[1].is_empty()
                || m_axis.outputs[0].len() != 1
            {
                return Ok(None);
            }
            // Axis 1 must act as the K (contraction) axis: present once in
            // both inputs, absent from the output.
            let k_axis = einsum.axes.axis((InOut::In(0), 1))?;
            if k_axis.inputs[0].len() != 1
                || k_axis.inputs[1].len() != 1
                || !k_axis.outputs[0].is_empty()
            {
                return Ok(None);
            }
            // The einsum is "abstract" when some axis carried only by the
            // other input maps to a symbolic (non-constant) output dimension.
            for axis in einsum.axes.iter_all_axes() {
                if axis != k_axis
                    && axis != m_axis
                    && axis.inputs[0].is_empty()
                    && axis.inputs[1].len() == 1
                    && axis.outputs[0].len() == 1
                    && snode.outputs[0].fact.shape[axis.outputs[0][0]].as_i64().is_none()
                {
                    have_abstract_einsum = true;
                }
            }
        } else {
            // Any other consumer disqualifies this constant.
            return Ok(None);
        }
    }
    if node.outputs[0].successors.len() > 1 || have_abstract_einsum {
        // TODO(wip): emit a TypedModelPatch pre-packing the weights for
        // these consumers. Detection only for now.
    }
    Ok(None)
}
}