Refactor repr quantized tensor handle
laggui committed Oct 15, 2024
1 parent 0452d2c commit 19199af
Showing 6 changed files with 97 additions and 92 deletions.
6 changes: 3 additions & 3 deletions crates/burn-fusion/src/backend.rs
@@ -5,7 +5,7 @@ use crate::{
 use burn_tensor::{
     backend::{Backend, DeviceOps, SyncType},
     ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
-    repr::{OperationDescription, ReprBackend, TensorHandle},
+    repr::{OperationDescription, QuantizedKind, ReprBackend, TensorHandle},
     Device,
 };
 use serde::{de::DeserializeOwned, Serialize};
@@ -184,7 +184,7 @@ impl<B: FusionBackend> ReprBackend for Fusion<B> {
     }

     fn quantized_tensor(
-        _handles: Vec<TensorHandle<Self::Handle>>,
+        _handles: QuantizedKind<TensorHandle<Self::Handle>>,
         _scheme: burn_tensor::quantization::QuantizationScheme,
     ) -> QuantizedTensor<Self> {
         todo!() // not as simple
@@ -202,7 +202,7 @@ impl<B: FusionBackend> ReprBackend for Fusion<B> {
         tensor
     }

-    fn quantized_tensor_handle(_tensor: QuantizedTensor<Self>) -> Vec<Self::Handle> {
+    fn quantized_tensor_handle(_tensor: QuantizedTensor<Self>) -> QuantizedKind<Self::Handle> {
         todo!() // not as simple
     }
 }
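The pattern above repeats across every backend in this commit: an order-dependent Vec<TensorHandle<...>> becomes the named QuantizedKind struct. A minimal standalone sketch of the difference, using plain u32 values as stand-ins for real handles (illustrative only, not the Burn API):

#[derive(Debug)]
struct QuantizedKind<T> {
    tensor: T,
    scale: T,
    offset: Option<T>,
}

// Old convention: recover [tensor, scale, <offset>] from a Vec by popping in
// reverse order; the length is only checked at runtime.
fn from_vec(mut handles: Vec<u32>) -> QuantizedKind<u32> {
    let offset = if handles.len() == 3 { handles.pop() } else { None };
    let scale = handles.pop().expect("missing scale handle");
    let tensor = handles.pop().expect("missing tensor handle");
    QuantizedKind { tensor, scale, offset }
}

fn main() {
    // New convention: the field names carry the [tensor, scale, <offset>]
    // contract, and forgetting the scale is a compile error, not a panic.
    let affine = QuantizedKind { tensor: 0u32, scale: 1, offset: Some(2) };
    let symmetric = from_vec(vec![3, 4]);
    println!("{affine:?} {symmetric:?}");
}

With the Vec, the [tensor, scale, <offset>] layout was a comment-level convention backed by asserts; with the struct, it is part of the type.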
52 changes: 26 additions & 26 deletions crates/burn-fusion/src/ops/qtensor.rs
@@ -6,6 +6,7 @@ use burn_tensor::{
     repr::{
         DequantizeOperationDescription, FloatOperationDescription, HandleContainer,
         OperationDescription, QuantizationParametersDescription, QuantizeOperationDescription,
+        QuantizedKind,
     },
     DType, Device, Element, Shape, TensorData,
 };
@@ -25,19 +26,17 @@ impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
         let tensor = B::q_from_data(data, device);
         let shape = B::q_shape(&tensor);

-        let mut handles = B::quantized_tensor_handle(tensor);
+        let handles = B::quantized_tensor_handle(tensor);
         let qparams = match strategy {
             QuantizationStrategy::PerTensorAffineInt8(_) => {
-                let num_handles = handles.len();
-                assert_eq!(
-                    num_handles, 3,
-                    "Expected 3 handles for quantized tensor, got {num_handles}"
-                );
-                let offset = handles.pop().unwrap();
-                let scale = handles.pop().unwrap();
+                let offset = if let Some(offset) = handles.offset {
+                    offset
+                } else {
+                    panic!("Expected offset for quantized tensor.");
+                };
                 FusionQuantizationParameters {
                     scale: client.register_tensor(
-                        scale,
+                        handles.scale,
                         vec![1],
                         StreamId::current(),
                         B::FloatElem::dtype(),
@@ -51,15 +50,13 @@ impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
                 }
             }
             QuantizationStrategy::PerTensorSymmetricInt8(_) => {
-                let num_handles = handles.len();
-                assert_eq!(
-                    num_handles, 2,
-                    "Expected 2 handles for quantized tensor, got {num_handles}"
+                assert!(
+                    handles.offset.is_none(),
+                    "Offset should not be provided for symmetric quantization."
                 );
-                let scale = handles.pop().unwrap();
                 FusionQuantizationParameters {
                     scale: client.register_tensor(
-                        scale,
+                        handles.scale,
                         vec![1],
                         StreamId::current(),
                         B::FloatElem::dtype(),
@@ -69,7 +66,7 @@ impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
             }
         };
         let qtensor = client.register_tensor(
-            handles.pop().unwrap(),
+            handles.tensor,
             shape.dims,
             StreamId::current(),
             B::QuantizedEncoding::dtype(),
@@ -111,17 +108,20 @@ impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {

            let qparams = QuantizationParametersPrimitive { scale, offset };
            let output = B::quantize(tensor, &self.desc.scheme, qparams);
-            if let Some(offset) = &self.desc.qparams.offset {
-                handles.register_quantized_tensor::<B>(
-                    &[&self.desc.out.id, &self.desc.qparams.scale.id, &offset.id],
-                    output,
-                );
+            let q_ids = if let Some(offset) = &self.desc.qparams.offset {
+                QuantizedKind {
+                    tensor: self.desc.out.id,
+                    scale: self.desc.qparams.scale.id,
+                    offset: Some(offset.id),
+                }
             } else {
-                handles.register_quantized_tensor::<B>(
-                    &[&self.desc.out.id, &self.desc.qparams.scale.id],
-                    output,
-                );
-            }
+                QuantizedKind {
+                    tensor: self.desc.out.id,
+                    scale: self.desc.qparams.scale.id,
+                    offset: None,
+                }
+            };
+            handles.register_quantized_tensor::<B>(&q_ids, output);
         }
     }

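The q_from_data change above replaces length assertions with field access. A self-contained sketch of the per-strategy validation, with simplified stand-in types for burn-fusion's handles and QuantizationStrategy (names and types here are illustrative only):

struct Handles {
    tensor: u32,
    scale: u32,
    offset: Option<u32>,
}

enum Strategy {
    PerTensorAffineInt8,
    PerTensorSymmetricInt8,
}

// Returns (tensor, scale, offset) after validating the offset against the strategy.
fn check(handles: &Handles, strategy: &Strategy) -> (u32, u32, Option<u32>) {
    match strategy {
        Strategy::PerTensorAffineInt8 => {
            // Affine quantization stores a zero-point, so the offset must exist.
            let offset = handles
                .offset
                .expect("Expected offset for quantized tensor.");
            (handles.tensor, handles.scale, Some(offset))
        }
        Strategy::PerTensorSymmetricInt8 => {
            // Symmetric quantization has no zero-point; a stray offset is a bug.
            assert!(
                handles.offset.is_none(),
                "Offset should not be provided for symmetric quantization."
            );
            (handles.tensor, handles.scale, None)
        }
    }
}

fn main() {
    let affine = Handles { tensor: 0, scale: 1, offset: Some(2) };
    assert_eq!(check(&affine, &Strategy::PerTensorAffineInt8), (0, 1, Some(2)));

    let symmetric = Handles { tensor: 3, scale: 4, offset: None };
    assert_eq!(check(&symmetric, &Strategy::PerTensorSymmetricInt8), (3, 4, None));
}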
17 changes: 14 additions & 3 deletions crates/burn-fusion/src/server.rs
@@ -3,7 +3,8 @@ use crate::{
     FusionBackend, FusionRuntime,
 };
 use burn_tensor::repr::{
-    HandleContainer, OperationDescription, QuantizedTensorDescription, TensorDescription, TensorId,
+    HandleContainer, OperationDescription, QuantizedKind, QuantizedTensorDescription,
+    TensorDescription, TensorId,
 };
 use std::sync::Arc;

@@ -183,18 +184,28 @@ where
             let scale_id = server_device.create_empty_handle();
             let offset_id = server_device.create_empty_handle();

+            let q_ids = QuantizedKind {
+                tensor: *tensor_id,
+                scale: *scale_id,
+                offset: Some(*offset_id),
+            };
             server_device
                 .handles
-                .register_quantized_tensor::<B>(&[&tensor_id, &scale_id, &offset_id], tensor);
+                .register_quantized_tensor::<B>(&q_ids, tensor);

             vec![tensor_id, scale_id, offset_id]
         } else {
             let tensor_id = server_device.create_empty_handle();
             let scale_id = server_device.create_empty_handle();

+            let q_ids = QuantizedKind {
+                tensor: *tensor_id,
+                scale: *scale_id,
+                offset: None,
+            };
             server_device
                 .handles
-                .register_quantized_tensor::<B>(&[&tensor_id, &scale_id], tensor);
+                .register_quantized_tensor::<B>(&q_ids, tensor);

             vec![tensor_id, scale_id]
         }
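A sketch of the id-registration flow above, simplified: TensorId and create_empty_handle below are toy stand-ins, assuming (as the `*` dereferences in the diff suggest) that the server hands out ids behind an Arc:

use std::sync::Arc;

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct TensorId(u64);

#[derive(Debug)]
struct QuantizedKind<T> {
    tensor: T,
    scale: T,
    offset: Option<T>,
}

// Stand-in for the server's allocator: hand out shared ids.
fn create_empty_handle(next: &mut u64) -> Arc<TensorId> {
    let id = TensorId(*next);
    *next += 1;
    Arc::new(id)
}

fn main() {
    let mut counter = 0;
    let tensor_id = create_empty_handle(&mut counter);
    let scale_id = create_empty_handle(&mut counter);
    let offset_id = create_empty_handle(&mut counter);

    // Mirror of the affine branch above: copy the ids out of the Arcs with `*`
    // to build the registration key...
    let q_ids = QuantizedKind {
        tensor: *tensor_id,
        scale: *scale_id,
        offset: Some(*offset_id),
    };

    // ...while the Arc-wrapped ids are still returned to the caller, as in the
    // server's `vec![tensor_id, scale_id, offset_id]`.
    let returned = vec![tensor_id, scale_id, offset_id];
    assert_eq!(q_ids.tensor, *returned[0]);
    assert_eq!(q_ids.scale, *returned[1]);
    assert_eq!(q_ids.offset, Some(*returned[2]));
}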
61 changes: 21 additions & 40 deletions crates/burn-jit/src/fusion/base.rs
@@ -7,7 +7,7 @@ use crate::{
 };
 use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime};
 use burn_tensor::quantization::QuantizationScheme;
-use burn_tensor::repr::TensorHandle;
+use burn_tensor::repr::{QuantizedKind, TensorHandle};
 use burn_tensor::{repr::ReprBackend, Shape};
 use core::marker::PhantomData;
 use cubecl::client::ComputeClient;
@@ -79,41 +79,22 @@ impl<R: JitRuntime, F: FloatElement, I: IntElement> ReprBackend for JitBackend<R
     }

     fn quantized_tensor(
-        handles: Vec<TensorHandle<Self::Handle>>,
+        handles: QuantizedKind<TensorHandle<Self::Handle>>,
         scheme: QuantizationScheme,
     ) -> burn_tensor::ops::QuantizedTensor<Self> {
-        match handles.len() {
-            // NOTE: the order of the handles is known [qtensor, scale, <offset>]
-            3 => {
-                let mut handles = handles;
-                let offset = handles.pop().unwrap();
-                let scale = handles.pop().unwrap();
-                let qtensor = handles.pop().unwrap();
-                QJitTensor {
-                    qtensor: qtensor.handle.into_tensor(qtensor.shape),
-                    scheme,
-                    qparams: JitQuantizationParameters {
-                        scale: scale.handle.into_tensor(scale.shape),
-                        offset: Some(offset.handle.into_tensor(offset.shape)),
-                    },
-                }
-            }
-            2 => {
-                let mut handles = handles;
-                let scale = handles.pop().unwrap();
-                let qtensor = handles.pop().unwrap();
-                QJitTensor {
-                    qtensor: qtensor.handle.into_tensor(qtensor.shape),
-                    scheme,
-                    qparams: JitQuantizationParameters {
-                        scale: scale.handle.into_tensor(scale.shape),
-                        offset: None,
-                    },
-                }
-            }
-            _ => {
-                panic!("Expected handles for the quantized tensor and its quantization parameters.")
-            }
-        }
+        let qtensor = handles.tensor.handle.into_tensor(handles.tensor.shape);
+        let scale = handles.scale.handle.into_tensor(handles.scale.shape);
+        let offset = handles.offset;
+
+        let qparams = JitQuantizationParameters {
+            scale,
+            offset: offset.map(|h| h.handle.into_tensor(h.shape)),
+        };
+
+        QJitTensor {
+            qtensor,
+            scheme,
+            qparams,
+        }
     }

@@ -131,14 +112,14 @@ impl<R: JitRuntime, F: FloatElement, I: IntElement> ReprBackend for JitBackend<R

     fn quantized_tensor_handle(
         tensor: burn_tensor::ops::QuantizedTensor<Self>,
-    ) -> Vec<Self::Handle> {
+    ) -> QuantizedKind<Self::Handle> {
         let qtensor: JitFusionHandle<R> = tensor.qtensor.into();
         let scale: JitFusionHandle<R> = tensor.qparams.scale.into();
-        if let Some(offset) = tensor.qparams.offset {
-            let offset: JitFusionHandle<R> = offset.into();
-            vec![qtensor, scale, offset]
-        } else {
-            vec![qtensor, scale]
-        }
+
+        QuantizedKind {
+            tensor: qtensor,
+            scale,
+            offset: tensor.qparams.offset.map(|offset| offset.into()),
+        }
     }
 }
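The rewrite above collapses the 3-handle and 2-handle match arms into a single path by letting the optional offset flow through Option::map. A toy sketch of that shape, with stand-in types in place of JitFusionHandle and JitQuantizationParameters:

struct TensorHandle {
    data: Vec<f32>,
}

struct Params {
    scale: f32,
    offset: Option<f32>,
}

// Stand-in for `handle.into_tensor(shape)`.
fn into_scalar(h: TensorHandle) -> f32 {
    h.data[0]
}

// Both the affine (with offset) and symmetric (without) cases take the same
// path: the optional offset flows through Option::map instead of a length match.
fn build_params(scale: TensorHandle, offset: Option<TensorHandle>) -> Params {
    Params {
        scale: into_scalar(scale),
        offset: offset.map(into_scalar),
    }
}

fn main() {
    let symmetric = build_params(TensorHandle { data: vec![0.5] }, None);
    assert_eq!(symmetric.scale, 0.5);
    assert!(symmetric.offset.is_none());

    let affine = build_params(
        TensorHandle { data: vec![0.5] },
        Some(TensorHandle { data: vec![-8.0] }),
    );
    assert_eq!(affine.offset, Some(-8.0));
}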
19 changes: 15 additions & 4 deletions crates/burn-tensor/src/repr/backend.rs
@@ -4,16 +4,27 @@ use crate::{
     quantization::QuantizationScheme,
     Shape,
 };
-use alloc::vec::Vec;

 /// A tensor representation containing a reference to a tensor resource with a given shape.
-pub struct TensorHandle<H> {
+#[derive(Clone)]
+pub struct TensorHandle<H: Clone> {
     /// The type that can be used to point to a tensor of any kind.
     pub handle: H,
     /// The shape associated to the tensor.
     pub shape: Shape,
 }

+/// A simple struct to encapsulate a quantized tensor kind.
+#[derive(Clone)]
+pub struct QuantizedKind<T: Clone> {
+    /// The quantized tensor.
+    pub tensor: T,
+    /// The scaling factor.
+    pub scale: T,
+    /// The zero-point offset.
+    pub offset: Option<T>,
+}
+
 /// Backend extension trait that allows an existing [backend](Backend) to use the Burn tensor representation
 /// for compilation purpose or other...
 pub trait ReprBackend: Backend {
@@ -28,7 +39,7 @@ pub trait ReprBackend: Backend {
     fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self>;
     /// Convert a [handle](ReprBackend::Handle) to a [quantized tensor](Backend::QuantizedTensorPrimitive).
     fn quantized_tensor(
-        handles: Vec<TensorHandle<Self::Handle>>,
+        handle: QuantizedKind<TensorHandle<Self::Handle>>,
         scheme: QuantizationScheme,
     ) -> QuantizedTensor<Self>;
@@ -40,5 +51,5 @@ pub trait ReprBackend: Backend {
     fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle;
     /// Convert a [quantized tensor](Backend::QuantizedTensorPrimitive) to a [handle](ReprBackend::Handle).
     /// A quantized tensor has multiple handles for the tensor itself and the quantization parameters.
-    fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Vec<Self::Handle>;
+    fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> QuantizedKind<Self::Handle>;
 }
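Because QuantizedKind<T> is generic, the same struct carries TensorHandle values, raw backend handles, and TensorIds throughout this commit. A hypothetical helper, not part of this commit, showing how those conversions could be written once as a map over the three fields (the T: Clone bound from the real definition is dropped here for brevity):

#[derive(Clone, Debug)]
struct QuantizedKind<T> {
    tensor: T,
    scale: T,
    offset: Option<T>,
}

impl<T> QuantizedKind<T> {
    // Apply the same conversion to the tensor, scale, and optional offset.
    fn map<U>(self, mut f: impl FnMut(T) -> U) -> QuantizedKind<U> {
        QuantizedKind {
            tensor: f(self.tensor),
            scale: f(self.scale),
            offset: self.offset.map(&mut f),
        }
    }
}

fn main() {
    let ids = QuantizedKind { tensor: 0u64, scale: 1, offset: None };
    // e.g. turn numeric ids into printable labels in one pass.
    let labels = ids.map(|id| format!("tensor-{id}"));
    println!("{labels:?}");
}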
34 changes: 18 additions & 16 deletions crates/burn-tensor/src/repr/handle.rs
@@ -7,7 +7,7 @@ use crate::{
 };
 use std::{collections::HashMap, sync::Arc};

-use super::{QuantizedTensorDescription, TensorHandle};
+use super::{QuantizedKind, QuantizedTensorDescription, TensorHandle};

 /// Keep all [tensor handles](ReprBackend::Handle) in one place and ensure that all resources
 /// are used optimally.
@@ -122,12 +122,14 @@ impl<H: Clone> HandleContainer<H> {
     where
         B: ReprBackend<Handle = H>,
     {
-        let qtensor = self.get_tensor_handle(&tensor.tensor);
-        let scale = self.get_tensor_handle(&tensor.qparams.scale);
-        let handles = if let Some(offset) = &tensor.qparams.offset {
-            vec![qtensor, scale, self.get_tensor_handle(offset)]
-        } else {
-            vec![qtensor, scale]
+        let handles = QuantizedKind {
+            tensor: self.get_tensor_handle(&tensor.tensor),
+            scale: self.get_tensor_handle(&tensor.qparams.scale),
+            offset: tensor
+                .qparams
+                .offset
+                .as_ref()
+                .map(|offset| self.get_tensor_handle(offset)),
         };
         B::quantized_tensor(handles, tensor.scheme.clone())
     }
@@ -144,20 +146,20 @@ impl<H: Clone> HandleContainer<H> {
     /// Register a new [quantized tensor](crate::backend::Backend::QuantizedTensorPrimitive) with the corresponding [tensor ids](TensorId).
     pub fn register_quantized_tensor<B>(
         &mut self,
-        ids: &[&TensorId],
+        id: &QuantizedKind<TensorId>,
         tensor: B::QuantizedTensorPrimitive,
     ) where
         B: ReprBackend<Handle = H>,
     {
         let handles = B::quantized_tensor_handle(tensor);
-        assert_eq!(
-            ids.len(),
-            handles.len(),
-            "Number of tensor ids and handles must match"
-        );

-        for (handle, id) in handles.into_iter().zip(ids) {
-            self.handles.insert(**id, Handle::Existing(handle));
+        self.handles
+            .insert(id.tensor, Handle::Existing(handles.tensor));
+        self.handles
+            .insert(id.scale, Handle::Existing(handles.scale));
+
+        if let (Some(id), Some(handle)) = (id.offset, handles.offset) {
+            self.handles.insert(id, Handle::Existing(handle));
         }
     }

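A toy model of the register path above, with a plain HashMap standing in for HandleContainer and a simplified Handle type: the ids and handles are paired field by field, so the old "number of tensor ids and handles must match" assertion has nothing left to check.

use std::collections::HashMap;

#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
struct TensorId(u64);

struct QuantizedKind<T> {
    tensor: T,
    scale: T,
    offset: Option<T>,
}

#[derive(Debug)]
enum Handle {
    Existing(String),
}

// Field-by-field pairing of ids and handles: a count mismatch is now
// impossible, and the offset is only registered when both sides have one.
fn register_quantized_tensor(
    map: &mut HashMap<TensorId, Handle>,
    id: &QuantizedKind<TensorId>,
    handles: QuantizedKind<String>,
) {
    map.insert(id.tensor, Handle::Existing(handles.tensor));
    map.insert(id.scale, Handle::Existing(handles.scale));
    if let (Some(id), Some(handle)) = (id.offset, handles.offset) {
        map.insert(id, Handle::Existing(handle));
    }
}

fn main() {
    let mut map = HashMap::new();
    let ids = QuantizedKind {
        tensor: TensorId(0),
        scale: TensorId(1),
        offset: None,
    };
    let handles = QuantizedKind {
        tensor: "qtensor".to_string(),
        scale: "scale".to_string(),
        offset: None,
    };
    register_quantized_tensor(&mut map, &ids, handles);
    // Only the tensor and scale handles were registered.
    assert_eq!(map.len(), 2);
    println!("{map:?}");
}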