From 19199afe7c60b35f7dfc0ca95663e2254b53422f Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Tue, 15 Oct 2024 15:30:13 -0400 Subject: [PATCH] Refactor repr quantized tensor handle --- crates/burn-fusion/src/backend.rs | 6 +-- crates/burn-fusion/src/ops/qtensor.rs | 52 +++++++++++----------- crates/burn-fusion/src/server.rs | 17 +++++-- crates/burn-jit/src/fusion/base.rs | 61 +++++++++----------------- crates/burn-tensor/src/repr/backend.rs | 19 ++++++-- crates/burn-tensor/src/repr/handle.rs | 34 +++++++------- 6 files changed, 97 insertions(+), 92 deletions(-) diff --git a/crates/burn-fusion/src/backend.rs b/crates/burn-fusion/src/backend.rs index dec6ba1bea..438c7004f3 100644 --- a/crates/burn-fusion/src/backend.rs +++ b/crates/burn-fusion/src/backend.rs @@ -5,7 +5,7 @@ use crate::{ use burn_tensor::{ backend::{Backend, DeviceOps, SyncType}, ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor}, - repr::{OperationDescription, ReprBackend, TensorHandle}, + repr::{OperationDescription, QuantizedKind, ReprBackend, TensorHandle}, Device, }; use serde::{de::DeserializeOwned, Serialize}; @@ -184,7 +184,7 @@ impl ReprBackend for Fusion { } fn quantized_tensor( - _handles: Vec>, + _handles: QuantizedKind>, _scheme: burn_tensor::quantization::QuantizationScheme, ) -> QuantizedTensor { todo!() // not as simple @@ -202,7 +202,7 @@ impl ReprBackend for Fusion { tensor } - fn quantized_tensor_handle(_tensor: QuantizedTensor) -> Vec { + fn quantized_tensor_handle(_tensor: QuantizedTensor) -> QuantizedKind { todo!() // not as simple } } diff --git a/crates/burn-fusion/src/ops/qtensor.rs b/crates/burn-fusion/src/ops/qtensor.rs index a6a51ce3a9..1f5f5e494a 100644 --- a/crates/burn-fusion/src/ops/qtensor.rs +++ b/crates/burn-fusion/src/ops/qtensor.rs @@ -6,6 +6,7 @@ use burn_tensor::{ repr::{ DequantizeOperationDescription, FloatOperationDescription, HandleContainer, OperationDescription, QuantizationParametersDescription, QuantizeOperationDescription, + QuantizedKind, }, DType, Device, Element, Shape, TensorData, }; @@ -25,19 +26,17 @@ impl QTensorOps for Fusion { let tensor = B::q_from_data(data, device); let shape = B::q_shape(&tensor); - let mut handles = B::quantized_tensor_handle(tensor); + let handles = B::quantized_tensor_handle(tensor); let qparams = match strategy { QuantizationStrategy::PerTensorAffineInt8(_) => { - let num_handles = handles.len(); - assert_eq!( - num_handles, 3, - "Expected 3 handles for quantized tensor, got {num_handles}" - ); - let offset = handles.pop().unwrap(); - let scale = handles.pop().unwrap(); + let offset = if let Some(offset) = handles.offset { + offset + } else { + panic!("Expected offset for quantized tensor."); + }; FusionQuantizationParameters { scale: client.register_tensor( - scale, + handles.scale, vec![1], StreamId::current(), B::FloatElem::dtype(), @@ -51,15 +50,13 @@ impl QTensorOps for Fusion { } } QuantizationStrategy::PerTensorSymmetricInt8(_) => { - let num_handles = handles.len(); - assert_eq!( - num_handles, 2, - "Expected 2 handles for quantized tensor, got {num_handles}" + assert!( + handles.offset.is_none(), + "Offset should not be provided for symmetric quantization." ); - let scale = handles.pop().unwrap(); FusionQuantizationParameters { scale: client.register_tensor( - scale, + handles.scale, vec![1], StreamId::current(), B::FloatElem::dtype(), @@ -69,7 +66,7 @@ impl QTensorOps for Fusion { } }; let qtensor = client.register_tensor( - handles.pop().unwrap(), + handles.tensor, shape.dims, StreamId::current(), B::QuantizedEncoding::dtype(), @@ -111,17 +108,20 @@ impl QTensorOps for Fusion { let qparams = QuantizationParametersPrimitive { scale, offset }; let output = B::quantize(tensor, &self.desc.scheme, qparams); - if let Some(offset) = &self.desc.qparams.offset { - handles.register_quantized_tensor::( - &[&self.desc.out.id, &self.desc.qparams.scale.id, &offset.id], - output, - ); + let q_ids = if let Some(offset) = &self.desc.qparams.offset { + QuantizedKind { + tensor: self.desc.out.id, + scale: self.desc.qparams.scale.id, + offset: Some(offset.id), + } } else { - handles.register_quantized_tensor::( - &[&self.desc.out.id, &self.desc.qparams.scale.id], - output, - ); - } + QuantizedKind { + tensor: self.desc.out.id, + scale: self.desc.qparams.scale.id, + offset: None, + } + }; + handles.register_quantized_tensor::(&q_ids, output); } } diff --git a/crates/burn-fusion/src/server.rs b/crates/burn-fusion/src/server.rs index 58cb7f9605..ce44d51c4a 100644 --- a/crates/burn-fusion/src/server.rs +++ b/crates/burn-fusion/src/server.rs @@ -3,7 +3,8 @@ use crate::{ FusionBackend, FusionRuntime, }; use burn_tensor::repr::{ - HandleContainer, OperationDescription, QuantizedTensorDescription, TensorDescription, TensorId, + HandleContainer, OperationDescription, QuantizedKind, QuantizedTensorDescription, + TensorDescription, TensorId, }; use std::sync::Arc; @@ -183,18 +184,28 @@ where let scale_id = server_device.create_empty_handle(); let offset_id = server_device.create_empty_handle(); + let q_ids = QuantizedKind { + tensor: *tensor_id, + scale: *scale_id, + offset: Some(*offset_id), + }; server_device .handles - .register_quantized_tensor::(&[&tensor_id, &scale_id, &offset_id], tensor); + .register_quantized_tensor::(&q_ids, tensor); vec![tensor_id, scale_id, offset_id] } else { let tensor_id = server_device.create_empty_handle(); let scale_id = server_device.create_empty_handle(); + let q_ids = QuantizedKind { + tensor: *tensor_id, + scale: *scale_id, + offset: None, + }; server_device .handles - .register_quantized_tensor::(&[&tensor_id, &scale_id], tensor); + .register_quantized_tensor::(&q_ids, tensor); vec![tensor_id, scale_id] } diff --git a/crates/burn-jit/src/fusion/base.rs b/crates/burn-jit/src/fusion/base.rs index 2b58692f3a..527c4bd98e 100644 --- a/crates/burn-jit/src/fusion/base.rs +++ b/crates/burn-jit/src/fusion/base.rs @@ -7,7 +7,7 @@ use crate::{ }; use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime}; use burn_tensor::quantization::QuantizationScheme; -use burn_tensor::repr::TensorHandle; +use burn_tensor::repr::{QuantizedKind, TensorHandle}; use burn_tensor::{repr::ReprBackend, Shape}; use core::marker::PhantomData; use cubecl::client::ComputeClient; @@ -79,41 +79,22 @@ impl ReprBackend for JitBackend>, + handles: QuantizedKind>, scheme: QuantizationScheme, ) -> burn_tensor::ops::QuantizedTensor { - match handles.len() { - // NOTE: the order of the handles is known [qtensor, scale, ] - 3 => { - let mut handles = handles; - let offset = handles.pop().unwrap(); - let scale = handles.pop().unwrap(); - let qtensor = handles.pop().unwrap(); - QJitTensor { - qtensor: qtensor.handle.into_tensor(qtensor.shape), - scheme, - qparams: JitQuantizationParameters { - scale: scale.handle.into_tensor(scale.shape), - offset: Some(offset.handle.into_tensor(offset.shape)), - }, - } - } - 2 => { - let mut handles = handles; - let scale = handles.pop().unwrap(); - let qtensor = handles.pop().unwrap(); - QJitTensor { - qtensor: qtensor.handle.into_tensor(qtensor.shape), - scheme, - qparams: JitQuantizationParameters { - scale: scale.handle.into_tensor(scale.shape), - offset: None, - }, - } - } - _ => { - panic!("Expected handles for the quantized tensor and its quantization parameters.") - } + let qtensor = handles.tensor.handle.into_tensor(handles.tensor.shape); + let scale = handles.scale.handle.into_tensor(handles.scale.shape); + let offset = handles.offset; + + let qparams = JitQuantizationParameters { + scale, + offset: offset.map(|h| h.handle.into_tensor(h.shape)), + }; + + QJitTensor { + qtensor, + scheme, + qparams, } } @@ -131,14 +112,14 @@ impl ReprBackend for JitBackend, - ) -> Vec { + ) -> QuantizedKind { let qtensor: JitFusionHandle = tensor.qtensor.into(); let scale: JitFusionHandle = tensor.qparams.scale.into(); - if let Some(offset) = tensor.qparams.offset { - let offset: JitFusionHandle = offset.into(); - vec![qtensor, scale, offset] - } else { - vec![qtensor, scale] + + QuantizedKind { + tensor: qtensor, + scale, + offset: tensor.qparams.offset.map(|offset| offset.into()), } } } diff --git a/crates/burn-tensor/src/repr/backend.rs b/crates/burn-tensor/src/repr/backend.rs index 9853a2eafb..62ca14f9e3 100644 --- a/crates/burn-tensor/src/repr/backend.rs +++ b/crates/burn-tensor/src/repr/backend.rs @@ -4,16 +4,27 @@ use crate::{ quantization::QuantizationScheme, Shape, }; -use alloc::vec::Vec; /// A tensor representation containing a reference to a tensor resource with a given shape. -pub struct TensorHandle { +#[derive(Clone)] +pub struct TensorHandle { /// The type that can be used to point to a tensor of any kind. pub handle: H, /// The shape associated to the tensor. pub shape: Shape, } +/// A simple struct to encapsulate a quantized tensor kind. +#[derive(Clone)] +pub struct QuantizedKind { + /// The quantized tensor. + pub tensor: T, + /// The scaling factor. + pub scale: T, + /// The zero-point offset. + pub offset: Option, +} + /// Backend extension trait that allows an existing [backend](Backend) to use the Burn tensor representation /// for compilation purpose or other... pub trait ReprBackend: Backend { @@ -28,7 +39,7 @@ pub trait ReprBackend: Backend { fn bool_tensor(handle: TensorHandle) -> BoolTensor; /// Convert a [handle](ReprBackend::Handle) to a [quantized tensor](Backend::QuantizedTensorPrimitive). fn quantized_tensor( - handles: Vec>, + handle: QuantizedKind>, scheme: QuantizationScheme, ) -> QuantizedTensor; @@ -40,5 +51,5 @@ pub trait ReprBackend: Backend { fn bool_tensor_handle(tensor: BoolTensor) -> Self::Handle; /// Convert a [quantized tensor](Backend::QuantizedTensorPrimitive) to a [handle](ReprBackend::Handle). /// A quantized tensor has multiple handles for the tensor itself and the quantization parameters. - fn quantized_tensor_handle(tensor: QuantizedTensor) -> Vec; + fn quantized_tensor_handle(tensor: QuantizedTensor) -> QuantizedKind; } diff --git a/crates/burn-tensor/src/repr/handle.rs b/crates/burn-tensor/src/repr/handle.rs index 6160b45b3a..20225ed3f6 100644 --- a/crates/burn-tensor/src/repr/handle.rs +++ b/crates/burn-tensor/src/repr/handle.rs @@ -7,7 +7,7 @@ use crate::{ }; use std::{collections::HashMap, sync::Arc}; -use super::{QuantizedTensorDescription, TensorHandle}; +use super::{QuantizedKind, QuantizedTensorDescription, TensorHandle}; /// Keep all [tensor handles](ReprBackend::Handle) in one place and ensure that all resources /// are used optimally. @@ -122,12 +122,14 @@ impl HandleContainer { where B: ReprBackend, { - let qtensor = self.get_tensor_handle(&tensor.tensor); - let scale = self.get_tensor_handle(&tensor.qparams.scale); - let handles = if let Some(offset) = &tensor.qparams.offset { - vec![qtensor, scale, self.get_tensor_handle(offset)] - } else { - vec![qtensor, scale] + let handles = QuantizedKind { + tensor: self.get_tensor_handle(&tensor.tensor), + scale: self.get_tensor_handle(&tensor.qparams.scale), + offset: tensor + .qparams + .offset + .as_ref() + .map(|offset| self.get_tensor_handle(offset)), }; B::quantized_tensor(handles, tensor.scheme.clone()) } @@ -144,20 +146,20 @@ impl HandleContainer { /// Register a new [quantized tensor](crate::backend::Backend::QuantizedTensorPrimitive) with the corresponding [tensor ids](TensorId). pub fn register_quantized_tensor( &mut self, - ids: &[&TensorId], + id: &QuantizedKind, tensor: B::QuantizedTensorPrimitive, ) where B: ReprBackend, { let handles = B::quantized_tensor_handle(tensor); - assert_eq!( - ids.len(), - handles.len(), - "Number of tensor ids and handles must match" - ); - - for (handle, id) in handles.into_iter().zip(ids) { - self.handles.insert(**id, Handle::Existing(handle)); + + self.handles + .insert(id.tensor, Handle::Existing(handles.tensor)); + self.handles + .insert(id.scale, Handle::Existing(handles.scale)); + + if let (Some(id), Some(handle)) = (id.offset, handles.offset) { + self.handles.insert(id, Handle::Existing(handle)); } }