Kits reloaded #1633

Open · wants to merge 19 commits into main
Changes from 1 commit
sever dep from BQF to ShapeFact
kali committed Jan 30, 2025
commit 20b2fc655ecb2ad6ced8f32b6f63a92e3c17f091
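
This commit severs `BlockQuantFact`'s (the BQF of the commit title) dependency on `ShapeFact`: a block-quantized payload always has a fully concrete shape, so the fact now stores a plain `TVec<usize>` instead of a symbolic-capable `ShapeFact`. Call sites that previously went through `as_concrete().unwrap()` can now borrow the shape directly, and the few places that genuinely need symbolic dimensions lift the concrete values with `to_dim()`. A minimal, self-contained sketch of the idea, with `Vec<Option<usize>>` standing in for `ShapeFact` (an illustration, not tract's real API):

```rust
// Stand-in for the old symbolic-capable shape: None models a symbolic TDim.
struct SymbolicShape {
    dims: Vec<Option<usize>>,
}

impl SymbolicShape {
    // The fallible step the commit removes: fails if any dim is symbolic.
    fn as_concrete(&self) -> Option<Vec<usize>> {
        self.dims.iter().cloned().collect()
    }
}

// Stand-in for the new field: concrete by construction, nothing to unwrap.
struct ConcreteShape {
    dims: Vec<usize>,
}

fn main() {
    let old = SymbolicShape { dims: vec![Some(4096), Some(4096)] };
    let new = ConcreteShape { dims: vec![4096, 4096] };
    // Before: every consumer carried this unwrap; after: a direct borrow.
    assert_eq!(old.as_concrete().unwrap(), new.dims);
}
```
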
core/src/ops/array/gather.rs (2 additions, 2 deletions)

```diff
@@ -52,8 +52,8 @@ impl Gather {
 
     fn eval_bq_to_f16(&self, data: &BlockQuantValue, indices: &TValue) -> TractResult<Tensor> {
         ensure!(self.axis == 0);
-        ensure!(data.fact.shape.rank() == 2);
-        let data_shape = data.fact.shape.as_concrete().unwrap();
+        ensure!(data.fact.shape.len() == 2);
+        let data_shape = &data.fact.shape;
         let output_shape = &*self.compute_output_shape(data_shape, indices.shape())?;
         let mut output = unsafe { Tensor::uninitialized::<f16>(output_shape)? };
         let indices_slice = indices.as_slice::<i64>()?;
```
core/src/ops/einsum/as_matmul.rs (13 additions, 13 deletions)

```diff
@@ -252,21 +252,21 @@ impl TypedOp for BasicMatMul {
                 .quantize_output
                 .unwrap_or(a.datum_type)
                 .fact(self.output_shape(&a.shape, &b.shape))))
-        } else if let Some(opf) =
-            inputs[0].opaque_fact.as_ref().and_then(|of| of.downcast_ref::<BlockQuantFact>())
-        {
-            let a_shape: ShapeFact = a.shape.iter().chain(opf.shape.iter()).collect();
-            Ok(tvec!(self
-                .quantize_output
-                .unwrap_or(b.datum_type)
-                .fact(self.output_shape(&a_shape, &b.shape))))
-        } else if let Some(bqv) = inputs[0]
-            .konst
+        } else if let Some(opf) = inputs[0]
+            .opaque_fact
             .as_ref()
-            .and_then(|k| k.to_scalar::<Opaque>().ok())
-            .and_then(|o| o.downcast_ref::<BlockQuantValue>())
+            .and_then(|of| of.downcast_ref::<BlockQuantFact>())
+            .or_else(|| {
+                inputs[0]
+                    .konst
+                    .as_ref()
+                    .and_then(|k| k.to_scalar::<Opaque>().ok())
+                    .and_then(|o| o.downcast_ref::<BlockQuantValue>())
+                    .map(|v| &v.fact)
+            })
         {
-            let a_shape: ShapeFact = a.shape.iter().chain(bqv.fact.shape.iter()).collect();
+            let a_shape: ShapeFact =
+                a.shape.iter().cloned().chain(opf.shape.iter().map(|d| d.to_dim())).collect();
             Ok(tvec!(self
                 .quantize_output
                 .unwrap_or(b.datum_type)
```
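
The two collapsed `else if` branches built the same output fact from the same block-quant shape; the merge hinges on `Option::or_else`: prefer the `BlockQuantFact` attached to the input's `opaque_fact`, and only if that is absent dig the fact out of a constant `BlockQuantValue`. A self-contained sketch of the lookup pattern, with stand-in types replacing tract's fact machinery:

```rust
// Stand-ins for BlockQuantFact, BlockQuantValue and the input outlet.
struct Fact(&'static str);
struct Value {
    fact: Fact,
}
struct Input {
    opaque_fact: Option<Fact>,
    konst: Option<Value>,
}

// Prefer the type-level fact; fall back to the fact carried by a constant.
fn resolve(input: &Input) -> Option<&Fact> {
    input.opaque_fact.as_ref().or_else(|| input.konst.as_ref().map(|v| &v.fact))
}

fn main() {
    let typed = Input { opaque_fact: Some(Fact("from opaque_fact")), konst: None };
    let konst = Input { opaque_fact: None, konst: Some(Value { fact: Fact("from konst") }) };
    assert_eq!(resolve(&typed).unwrap().0, "from opaque_fact");
    assert_eq!(resolve(&konst).unwrap().0, "from konst");
}
```
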
core/src/ops/einsum/mod.rs (3 additions, 6 deletions)

```diff
@@ -31,12 +31,9 @@ pub fn block_quant_aware_input_shape(fact: &TypedFact) -> TractResult<Cow<[TDim]>> {
     let Some(bqf) = opaque_fact.downcast_ref::<BlockQuantFact>() else {
         bail!("Datum fact is opaque, but no opaque fact was found.")
     };
-    if bqf.shape.rank() == 0 {
-        Ok(Cow::Borrowed(&*bqf.shape))
-    } else {
-        let shape: Vec<TDim> = fact.shape.iter().chain(bqf.shape.iter()).cloned().collect();
-        Ok(Cow::Owned(shape))
-    }
+    let shape: Vec<TDim> =
+        fact.shape.iter().cloned().chain(bqf.shape.iter().map(|d| d.to_dim())).collect();
+    Ok(Cow::Owned(shape))
 }
 
 #[derive(Clone, Hash)]
```
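
With the block shape now concrete, the function no longer special-cases rank 0: chaining an empty block shape onto the outer shape is a no-op, so the single concatenation path covers it. A quick check with plain integers standing in for `TDim`:

```rust
// u64 stands in for TDim; the point is only that chaining an empty suffix
// leaves the outer shape unchanged, which makes a rank-0 branch redundant.
fn aware_shape(fact_shape: &[u64], bqf_shape: &[usize]) -> Vec<u64> {
    fact_shape.iter().cloned().chain(bqf_shape.iter().map(|&d| d as u64)).collect()
}

fn main() {
    assert_eq!(aware_shape(&[2, 3], &[]), vec![2, 3]); // rank-0 block shape: no-op
    assert_eq!(aware_shape(&[2], &[4096, 32]), vec![2, 4096, 32]);
}
```
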
core/src/ops/matmul/de_block_quant.rs (8 additions, 8 deletions)

```diff
@@ -9,7 +9,7 @@ use crate::transform::ModelTransform;
 #[derive(Clone, Hash)]
 pub struct BlockQuantFact {
     pub format: Box<dyn BlockQuant>,
-    pub shape: ShapeFact,
+    pub shape: TVec<usize>,
 }
 
 impl std::fmt::Debug for BlockQuantFact {
@@ -19,9 +19,9 @@ impl std::fmt::Debug for BlockQuantFact {
 }
 
 impl OpaqueFact for BlockQuantFact {
-
     fn mem_size(&self) -> TDim {
-        self.shape.volume() * self.format.block_bytes()
+        (self.shape.iter().product::<usize>() / self.format.block_len() * self.format.block_bytes())
+            .to_dim()
     }
 }
```
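The old `mem_size` multiplied the element count (`shape.volume()`) by the per-block byte size, overstating memory by a factor of `block_len`; the new expression first divides the element count into blocks. A worked check, assuming GGML-style Q4_0 parameters (32 elements per block, 18 bytes per block: 16 bytes of 4-bit values plus a 2-byte `f16` scale; illustrative constants, not read from tract):

```rust
// Mirror of the new formula with plain integers.
fn mem_size(shape: &[usize], block_len: usize, block_bytes: usize) -> usize {
    shape.iter().product::<usize>() / block_len * block_bytes
}

fn main() {
    // A 4096x4096 weight matrix in Q4_0-like blocks:
    // 16_777_216 elements / 32 per block = 524_288 blocks, x 18 bytes each.
    assert_eq!(mem_size(&[4096, 4096], 32, 18), 9_437_184); // ~9 MiB
    // The old formula would have returned 16_777_216 * 18 = 301_989_888.
}
```
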
```diff
@@ -84,9 +84,9 @@ fn block_quant_einsum_weights(
     if a.konst.is_none() || a.rank() != 2 {
         return Ok(None);
     }
-    let a: &Tensor = a.konst.as_ref().unwrap();
     let AxesOrPatch::Annotated(op) = ensure_mkn_axes(op, model, node)? else { return Ok(None) };
     if op.a_m() == 1 && op.a_k() == 0 {
+        let a: &Tensor = a.konst.as_ref().unwrap();
         let mut patch = TypedModelPatch::default();
         let konst =
             patch.add_const(&model.node(node.inputs[0].node).name, a.clone().move_axis(1, 0)?)?;
@@ -103,13 +103,13 @@
     }
     let format = Q4_0;
     let mut patch = TypedModelPatch::default();
-    let weights = if a.datum_type == f16::datum_type() {
-        format.quant_f16(a.konst.as_ref().unwrap().as_slice::<f16>()?)?
+    let weights = if a.datum_type() == f16::datum_type() {
+        format.quant_f16(a.as_slice::<f16>()?)?
     } else {
-        format.quant_f32(a.konst.as_ref().unwrap().cast_to::<f32>()?.as_slice::<f32>()?)?
+        format.quant_f32(a.cast_to::<f32>()?.as_slice::<f32>()?)?
     };
     let name = &model.node(node.inputs[0].node).name;
-    let fact = BlockQuantFact { format: Box::new(format), shape: a.shape.clone() };
+    let fact = BlockQuantFact { format: Box::new(format), shape: a.shape().into() };
     let value = BlockQuantValue { fact: fact.clone(), value: weights };
     let weights = patch.wire_node(
         format!("{name}.bq"),
```
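
With the field now `TVec<usize>`, building the fact from a constant tensor is just `shape: a.shape().into()`, a slice-to-owned copy, instead of cloning a `ShapeFact`. A small sketch of that construction, with `Vec<usize>` standing in for `TVec<usize>` and a string standing in for the boxed format:

```rust
// Stand-ins: only the shape plumbing matters here.
#[derive(Clone, Debug, PartialEq)]
struct BlockQuantFact {
    format: &'static str, // stand-in for Box<dyn BlockQuant>
    shape: Vec<usize>,    // the new concrete shape field
}

fn main() {
    let tensor_shape: &[usize] = &[4096, 4096]; // what a.shape() yields
    // After this commit: a plain copy of the concrete dims, no ShapeFact.
    let fact = BlockQuantFact { format: "Q4_0", shape: tensor_shape.into() };
    assert_eq!(fact.shape, vec![4096, 4096]);
}
```
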
nnef/src/ser.rs (1 addition, 1 deletion)

```diff
@@ -403,7 +403,7 @@ impl<'a> IntoAst<'a> {
         let id = self.scoped_id(&name);
         let shape = if tensor.datum_type().is_opaque() {
             if let Some(bqv) = tensor.to_scalar::<Opaque>()?.downcast_ref::<BlockQuantValue>() {
-                bqv.fact.shape.as_concrete().unwrap()
+                &bqv.fact.shape
             } else {
                 bail!("Unexpected opaque tensor in serialization {tensor:?}");
             }
```