Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions crates/burn-autodiff/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
type BoolTensorPrimitive = B::BoolTensorPrimitive;
type BoolElem = B::BoolElem;

type ComplexTensorPrimitive = B::ComplexTensorPrimitive;
type ComplexElem = B::ComplexElem;

type QuantizedTensorPrimitive = B::QuantizedTensorPrimitive;
type QuantizedEncoding = B::QuantizedEncoding;

Expand Down
136 changes: 136 additions & 0 deletions crates/burn-autodiff/src/ops/complex_tensor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
use crate::{Autodiff, checkpoint::strategy::CheckpointStrategy};
use burn_tensor::{
Distribution, Shape, TensorData,
backend::Backend,
ops::ComplexTensorOps,
ops::{ComplexTensor, Device},
};

// Complex tensor operations for the `Autodiff` backend decorator.
//
// Every implemented method forwards straight to the inner backend `B`; nothing
// is recorded on the autodiff graph, so gradients are NOT tracked through
// complex operations. NOTE(review): confirm this passthrough is the intended
// interim behavior until proper complex autodiff support lands.
impl<B: Backend, C: CheckpointStrategy> ComplexTensorOps<Self> for Autodiff<B, C> {
/// Creates a complex tensor from raw `TensorData` (inner-backend passthrough).
fn complex_from_data(data: TensorData, device: &Device<Self>) -> ComplexTensor<Self> {
B::complex_from_data(data, device)
}

/// Samples a complex tensor of `shape` from `distribution` (inner-backend passthrough).
fn complex_random(
shape: Shape,
distribution: Distribution,
device: &Device<Self>,
) -> ComplexTensor<Self> {
B::complex_random(shape, distribution, device)
}

/// Returns the tensor's shape (inner-backend passthrough).
fn complex_shape(tensor: &ComplexTensor<Self>) -> Shape {
B::complex_shape(tensor)
}

/// Copies the tensor's contents into `TensorData` without consuming it.
fn complex_to_data(tensor: &ComplexTensor<Self>) -> TensorData {
B::complex_to_data(tensor)
}

/// Returns the device the tensor lives on.
fn complex_device(tensor: &ComplexTensor<Self>) -> Device<Self> {
B::complex_device(tensor)
}

/// Moves the tensor to `device`, consuming it.
fn complex_to_device(
tensor: ComplexTensor<Self>,
device: &Device<Self>,
) -> ComplexTensor<Self> {
B::complex_to_device(tensor, device)
}

/// Consumes the tensor and returns its contents as `TensorData`.
fn complex_into_data(tensor: ComplexTensor<Self>) -> TensorData {
B::complex_into_data(tensor)
}

/// Reshapes the tensor to `shape` (inner-backend passthrough).
fn complex_reshape(tensor: ComplexTensor<Self>, shape: Shape) -> ComplexTensor<Self> {
B::complex_reshape(tensor, shape)
}

/// Transposes the tensor (inner-backend passthrough).
fn complex_transpose(tensor: ComplexTensor<Self>) -> ComplexTensor<Self> {
B::complex_transpose(tensor)
}

// --- Elementwise arithmetic: direct delegation, no gradient tracking. ---

/// Elementwise complex addition.
fn complex_add(lhs: ComplexTensor<Self>, rhs: ComplexTensor<Self>) -> ComplexTensor<Self> {
B::complex_add(lhs, rhs)
}

/// Elementwise complex subtraction.
fn complex_sub(lhs: ComplexTensor<Self>, rhs: ComplexTensor<Self>) -> ComplexTensor<Self> {
B::complex_sub(lhs, rhs)
}

/// Elementwise complex multiplication.
fn complex_mul(lhs: ComplexTensor<Self>, rhs: ComplexTensor<Self>) -> ComplexTensor<Self> {
B::complex_mul(lhs, rhs)
}

/// Elementwise complex division.
fn complex_div(lhs: ComplexTensor<Self>, rhs: ComplexTensor<Self>) -> ComplexTensor<Self> {
B::complex_div(lhs, rhs)
}

/// Elementwise negation.
fn complex_neg(tensor: ComplexTensor<Self>) -> ComplexTensor<Self> {
B::complex_neg(tensor)
}

/// Elementwise complex conjugate.
fn complex_conj(tensor: ComplexTensor<Self>) -> ComplexTensor<Self> {
B::complex_conj(tensor)
}

// --- Complex <-> float conversions: unimplemented. These return (or take)
// `Autodiff`'s float primitive, which differs from the inner backend's float
// primitive, so plain delegation does not type-check; a wrapper that lifts
// the result onto the autodiff graph is still needed. All six panic via
// `todo!` when called. ---

/// Extracts the real part as a float tensor. Currently always panics (`todo!`).
fn complex_real(_tensor: ComplexTensor<Self>) -> <Self as Backend>::FloatTensorPrimitive {
// Since autodiff complex tensors are just the inner backend's complex tensors,
// and complex_real returns a float tensor, we need to convert it to an autodiff float tensor
todo!("Need to implement autodiff wrapper for complex_real")
}

/// Extracts the imaginary part as a float tensor. Currently always panics (`todo!`).
fn complex_imag(_tensor: ComplexTensor<Self>) -> <Self as Backend>::FloatTensorPrimitive {
todo!("Need to implement autodiff wrapper for complex_imag")
}

/// Computes the elementwise magnitude as a float tensor. Currently always panics (`todo!`).
fn complex_abs(_tensor: ComplexTensor<Self>) -> <Self as Backend>::FloatTensorPrimitive {
todo!("Need to implement autodiff wrapper for complex_abs")
}

/// Computes the elementwise argument (phase) as a float tensor. Currently always panics (`todo!`).
fn complex_arg(_tensor: ComplexTensor<Self>) -> <Self as Backend>::FloatTensorPrimitive {
todo!("Need to implement autodiff wrapper for complex_arg")
}

/// Builds a complex tensor from real/imaginary float parts. Currently always panics (`todo!`).
fn complex_from_parts(
_real: <Self as Backend>::FloatTensorPrimitive,
_imag: <Self as Backend>::FloatTensorPrimitive,
) -> ComplexTensor<Self> {
todo!("Need to implement autodiff wrapper for complex_from_parts")
}

/// Builds a complex tensor from polar magnitude/phase floats. Currently always panics (`todo!`).
fn complex_from_polar(
_magnitude: <Self as Backend>::FloatTensorPrimitive,
_phase: <Self as Backend>::FloatTensorPrimitive,
) -> ComplexTensor<Self> {
todo!("Need to implement autodiff wrapper for complex_from_polar")
}

// --- Transcendental functions: direct delegation, no gradient tracking. ---

/// Elementwise complex exponential.
fn complex_exp(tensor: ComplexTensor<Self>) -> ComplexTensor<Self> {
B::complex_exp(tensor)
}

/// Elementwise complex (natural) logarithm.
fn complex_log(tensor: ComplexTensor<Self>) -> ComplexTensor<Self> {
B::complex_log(tensor)
}

/// Elementwise complex power: `lhs` raised to the complex exponent `rhs`.
fn complex_powc(lhs: ComplexTensor<Self>, rhs: ComplexTensor<Self>) -> ComplexTensor<Self> {
B::complex_powc(lhs, rhs)
}

/// Elementwise complex square root.
fn complex_sqrt(tensor: ComplexTensor<Self>) -> ComplexTensor<Self> {
B::complex_sqrt(tensor)
}

/// Elementwise complex sine.
fn complex_sin(tensor: ComplexTensor<Self>) -> ComplexTensor<Self> {
B::complex_sin(tensor)
}

/// Elementwise complex cosine.
fn complex_cos(tensor: ComplexTensor<Self>) -> ComplexTensor<Self> {
B::complex_cos(tensor)
}

/// Elementwise complex tangent.
fn complex_tan(tensor: ComplexTensor<Self>) -> ComplexTensor<Self> {
B::complex_tan(tensor)
}
}
1 change: 1 addition & 0 deletions crates/burn-autodiff/src/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ mod activation;
mod backward;
mod base;
mod bool_tensor;
mod complex_tensor;
mod int_tensor;
mod module;
mod qtensor;
Expand Down
3 changes: 3 additions & 0 deletions crates/burn-candle/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ impl<F: FloatCandleElement, I: IntCandleElement> Backend for Candle<F, I> {
type BoolTensorPrimitive = CandleTensor;
type BoolElem = u8;

type ComplexTensorPrimitive = CandleTensor;
type ComplexElem = burn_tensor::Complex32;

type QuantizedTensorPrimitive = CandleQTensor;
type QuantizedEncoding = u8;

Expand Down
131 changes: 131 additions & 0 deletions crates/burn-candle/src/ops/complex_tensor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
use crate::{
Candle, CandleTensor,
element::{FloatCandleElement, IntCandleElement},
};
use burn_tensor::{Device, Distribution, Shape, TensorData, ops::ComplexTensorOps};

// Complex tensor operations for the Candle backend.
//
// Complex tensors are NOT supported by this backend yet: every method below
// unconditionally panics via `unimplemented!`. The impl presumably exists only
// so `Candle` keeps satisfying the full `Backend` contract (its associated
// `ComplexTensorPrimitive`/`ComplexElem` types are declared in backend.rs) —
// confirm before relying on any complex op here.
//
// # Panics
// Every function in this impl panics on every call.
impl<F: FloatCandleElement, I: IntCandleElement> ComplexTensorOps<Self> for Candle<F, I> {
fn complex_from_data(data: TensorData, _device: &Device<Self>) -> CandleTensor {
unimplemented!("Complex tensor operations are not yet implemented for Candle backend")
}

fn complex_random(
_shape: Shape,
_distribution: Distribution,
_device: &Device<Self>,
) -> CandleTensor {
unimplemented!("Complex tensor operations are not yet implemented for Candle backend")
}

fn complex_full(
_shape: Shape,
_fill_value: burn_tensor::Complex32,
_device: &Device<Self>,
) -> CandleTensor {
unimplemented!("Complex tensor operations are not yet implemented for Candle backend")
}

fn complex_shape(_tensor: &CandleTensor) -> Shape {
unimplemented!("Complex tensor operations are not yet implemented for Candle backend")
}

fn complex_to_data(_tensor: &CandleTensor) -> TensorData {
unimplemented!("Complex tensor operations are not yet implemented for Candle backend")
}

fn complex_device(_tensor: &CandleTensor) -> Device<Self> {
unimplemented!("Complex tensor operations are not yet implemented for Candle backend")
}

fn complex_to_device(_tensor: CandleTensor, _device: &Device<Self>) -> CandleTensor {
unimplemented!("Complex tensor operations are not yet implemented for Candle backend")
}

fn complex_into_data(_tensor: CandleTensor) -> TensorData {
unimplemented!("Complex tensor operations are not yet implemented for Candle backend")
}

fn complex_reshape(_tensor: CandleTensor, _shape: Shape) -> CandleTensor {
unimplemented!("Complex tensor operations are not yet implemented for Candle backend")
}

fn complex_transpose(_tensor: CandleTensor) -> CandleTensor {
unimplemented!("Complex tensor operations are not yet implemented for Candle backend")
}

fn complex_add(_lhs: CandleTensor, _rhs: CandleTensor) -> CandleTensor {
unimplemented!("Complex tensor operations are not yet implemented for Candle backend")
}

fn complex_sub(_lhs: CandleTensor, _rhs: CandleTensor) -> CandleTensor {
unimplemented!("Complex tensor operations are not yet implemented for Candle backend")
}

fn complex_mul(_lhs: CandleTensor, _rhs: CandleTensor) -> CandleTensor {
unimplemented!("Complex tensor operations are not yet implemented for Candle backend")
}

fn complex_div(_lhs: CandleTensor, _rhs: CandleTensor) -> CandleTensor {
unimplemented!("Complex tensor operations are not yet implemented for Candle backend")
}

fn complex_neg(_tensor: CandleTensor) -> CandleTensor {
unimplemented!("Complex tensor operations are not yet implemented for Candle backend")
}

fn complex_conj(_tensor: CandleTensor) -> CandleTensor {
unimplemented!("Complex tensor operations are not yet implemented for Candle backend")
}

fn complex_real(_tensor: CandleTensor) -> CandleTensor {
unimplemented!("Complex tensor operations are not yet implemented for Candle backend")
}

fn complex_imag(_tensor: CandleTensor) -> CandleTensor {
unimplemented!("Complex tensor operations are not yet implemented for Candle backend")
}

fn complex_abs(_tensor: CandleTensor) -> CandleTensor {
unimplemented!("Complex tensor operations are not yet implemented for Candle backend")
}

fn complex_arg(_tensor: CandleTensor) -> CandleTensor {
unimplemented!("Complex tensor operations are not yet implemented for Candle backend")
}

fn complex_from_parts(_real: CandleTensor, _imag: CandleTensor) -> CandleTensor {
unimplemented!("Complex tensor operations are not yet implemented for Candle backend")
}

fn complex_from_polar(_magnitude: CandleTensor, _phase: CandleTensor) -> CandleTensor {
unimplemented!("Complex tensor operations are not yet implemented for Candle backend")
}

fn complex_exp(_tensor: CandleTensor) -> CandleTensor {
unimplemented!("Complex tensor operations are not yet implemented for Candle backend")
}

fn complex_log(_tensor: CandleTensor) -> CandleTensor {
unimplemented!("Complex tensor operations are not yet implemented for Candle backend")
}

fn complex_powc(_lhs: CandleTensor, _rhs: CandleTensor) -> CandleTensor {
unimplemented!("Complex tensor operations are not yet implemented for Candle backend")
}

fn complex_sqrt(_tensor: CandleTensor) -> CandleTensor {
unimplemented!("Complex tensor operations are not yet implemented for Candle backend")
}

fn complex_sin(_tensor: CandleTensor) -> CandleTensor {
unimplemented!("Complex tensor operations are not yet implemented for Candle backend")
}

fn complex_cos(_tensor: CandleTensor) -> CandleTensor {
unimplemented!("Complex tensor operations are not yet implemented for Candle backend")
}

fn complex_tan(_tensor: CandleTensor) -> CandleTensor {
unimplemented!("Complex tensor operations are not yet implemented for Candle backend")
}
}
1 change: 1 addition & 0 deletions crates/burn-candle/src/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ mod activation;
mod base;
mod bool_tensor;
mod candle_utils;
mod complex_tensor;
mod int_tensor;
mod module;
mod qtensor;
Expand Down
2 changes: 2 additions & 0 deletions crates/burn-cubecl/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ where
type FloatTensorPrimitive = CubeTensor<R>;
type IntTensorPrimitive = CubeTensor<R>;
type BoolTensorPrimitive = CubeTensor<R>;
type ComplexTensorPrimitive = CubeTensor<R>;
type ComplexElem = burn_tensor::Complex32;
type QuantizedTensorPrimitive = CubeTensor<R>;
type QuantizedEncoding = u32;

Expand Down
8 changes: 8 additions & 0 deletions crates/burn-cubecl/src/fusion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> BackendIr
into_tensor(handle.handle, handle.shape)
}

/// Reconstructs a complex tensor from a fusion tensor handle + shape.
/// Mirrors the sibling `float_tensor`/`int_tensor` accessors in this impl.
fn complex_tensor(handle: TensorHandle<Self::Handle>) -> burn_tensor::ops::ComplexTensor<Self> {
into_tensor(handle.handle, handle.shape)
}

/// Converts a float tensor back into an opaque fusion handle (via `Into`).
fn float_tensor_handle(tensor: burn_tensor::ops::FloatTensor<Self>) -> Self::Handle {
tensor.into()
}
Expand All @@ -130,6 +134,10 @@ impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> BackendIr
/// Converts a quantized tensor back into an opaque fusion handle (via `Into`).
fn quantized_tensor_handle(tensor: burn_tensor::ops::QuantizedTensor<Self>) -> Self::Handle {
tensor.into()
}

/// Converts a complex tensor back into an opaque fusion handle (via `Into`),
/// the inverse of `complex_tensor` above in the same trait.
fn complex_tensor_handle(tensor: burn_tensor::ops::ComplexTensor<Self>) -> Self::Handle {
tensor.into()
}
}

impl<R: CubeRuntime, BT: BoolElement> FusionRuntime for FusionCubeRuntime<R, BT> {
Expand Down
Loading
Loading