Complex Backend #3608
Conversation
- Add Complex as first-class TensorKind alongside Float, Int, Bool
- Add ComplexTensorPrimitive and ComplexElem to Backend trait
- Add complex tensor type aliases and exports
- Add ComplexTensorPrimitive support to NdArray backend
- Implement complex arithmetic and transcendental functions in NdArray backend
- Add autodiff backend wrapper for complex tensors
- Begin enabling support across backend ecosystem
- Add high-level Tensor<B, D, Complex> API with BasicOps and Numeric traits
- Add complex-specific methods: conj(), real(), imag(), magnitude(), phase()
- Add creation utilities: from_parts(), from_polar(), zeros(), ones()
- Start adding test suite covering operations
- Remove non-existent testgen_complex!() macro call that was causing compilation errors
- Add ComplexTensorOps implementations for all backends (tch, candle, cubecl, fusion, router)
- Fix complex tensor assertion logic in CubeCL backend to avoid Float trait requirements
- Add missing transcendental functions (exp, log, sin, cos, tan, sqrt, powc) to Complex tensor API
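The helper methods listed above correspond to standard complex arithmetic. As a scalar-level illustration of that math (the `C32` type and these free-standing methods are a stand-in sketch, not the PR's actual tensor API):

```rust
/// Minimal stand-in for a 32-bit complex element.
#[derive(Clone, Copy, Debug, PartialEq)]
struct C32 {
    re: f32,
    im: f32,
}

impl C32 {
    /// Build from polar coordinates: r * (cos θ + i sin θ).
    fn from_polar(r: f32, theta: f32) -> Self {
        C32 { re: r * theta.cos(), im: r * theta.sin() }
    }

    /// Complex conjugate: negate the imaginary part.
    fn conj(self) -> Self {
        C32 { re: self.re, im: -self.im }
    }

    /// Magnitude: sqrt(re^2 + im^2), computed robustly via hypot.
    fn magnitude(self) -> f32 {
        self.re.hypot(self.im)
    }

    /// Phase angle in radians, via atan2.
    fn phase(self) -> f32 {
        self.im.atan2(self.re)
    }
}

fn main() {
    let z = C32::from_polar(2.0, std::f32::consts::FRAC_PI_2);
    assert!((z.magnitude() - 2.0).abs() < 1e-6);
    assert!((z.phase() - std::f32::consts::FRAC_PI_2).abs() < 1e-6);
    assert_eq!(z.conj().im, -z.im);
    println!("z = {z:?}");
}
```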
@laggui
Alternatively, if I make some more of burn-tensor public (such as `kind`), then I can lift almost everything out of burn-tensor.
@skewballfox for now, I think we should treat the complex backend as an extension. Thus, almost all types, traits and impls should live only in the `burn-complex` extension at this time. The only part that should be added to `burn-tensor` is the `DType` variant, since we don't currently have a way to add/support custom dtypes anyway. I believe the rest can easily live as an extension in a separate crate.
```rust
pub type ComplexTensor<B> = <B as ComplexTensorBackend>::ComplexTensorPrimitive;

pub trait ComplexTensorBackend: Backend {
    /// The inner backend type.
    type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem>;

    /// Tensor primitive to be used for all complex operations.
    type ComplexTensorPrimitive: TensorMetadata + 'static;

    /// Returns the real part of a complex tensor.
    fn real(tensor: ComplexTensor<Self>) -> FloatTensor<Self::InnerBackend>;

    /// Returns the imaginary part of a complex tensor.
    fn imag(tensor: ComplexTensor<Self>) -> FloatTensor<Self::InnerBackend>;

    fn to_complex(tensor: FloatTensor<Self::InnerBackend>) -> ComplexTensor<Self>;
}

/// A type-level representation of the kind of a complex tensor.
#[derive(Clone, Debug)]
pub struct Complex;

impl<B: ComplexTensorBackend> TensorKind<B> for Complex {
    type Primitive = B::ComplexTensorPrimitive;

    fn name() -> &'static str {
        "Complex"
    }
}
```
You can easily implement the tensor ops traits for the `Complex` type, e.g.

```rust
impl<B: ComplexTensorBackend> BasicOps<B> for Complex {
    // ...
}
```
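The type-level kind pattern used here selects a different primitive type per backend at compile time. A self-contained sketch of the same mechanism with stand-in traits (`MiniBackend`, `MiniKind`, and the primitives are illustrative, not burn's actual API):

```rust
use std::marker::PhantomData;

/// Stand-in for a backend: associates one concrete primitive type per kind.
trait MiniBackend {
    type FloatPrimitive;
    type ComplexPrimitive;
}

/// Stand-in for burn's `TensorKind<B>`: maps a kind marker to a primitive.
trait MiniKind<B: MiniBackend> {
    type Primitive;
    fn name() -> &'static str;
}

struct Float;
struct Complex;

impl<B: MiniBackend> MiniKind<B> for Float {
    type Primitive = B::FloatPrimitive;
    fn name() -> &'static str { "Float" }
}

impl<B: MiniBackend> MiniKind<B> for Complex {
    type Primitive = B::ComplexPrimitive;
    fn name() -> &'static str { "Complex" }
}

/// A tensor parameterized by backend and kind; the stored primitive type
/// is chosen at compile time through the kind marker.
struct MiniTensor<B: MiniBackend, K: MiniKind<B>> {
    primitive: K::Primitive,
    _backend: PhantomData<B>,
}

struct Cpu;
impl MiniBackend for Cpu {
    type FloatPrimitive = Vec<f32>;
    type ComplexPrimitive = Vec<(f32, f32)>; // interleaved (re, im) pairs
}

fn main() {
    let t: MiniTensor<Cpu, Complex> = MiniTensor {
        primitive: vec![(1.0, 2.0)],
        _backend: PhantomData,
    };
    println!(
        "{} tensor with {} element(s); the other kind is {}",
        <Complex as MiniKind<Cpu>>::name(),
        t.primitive.len(),
        <Float as MiniKind<Cpu>>::name(),
    );
}
```

The payoff of this pattern is that an extension crate can add a new kind marker (here `Complex`) without touching the backend trait's existing kinds.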
For the element type, I think the `make_elem!` macro will not work because `ToElement` does not have `to_complex`, but that is fine. The macro was mostly to avoid repeating the implementation; we can either implement the `Element` trait manually for these types or make the macro a bit more flexible for custom external types. Maybe that will require adding `ToComplex` (in the complex crate) and implementing it for types that implement `ToElement`, so we can convert types <> complex.
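One shape the suggested `ToComplex` bridge could take, sketched with stand-in types (the `Complex32` struct and the `Into<f32>` bound here are assumptions standing in for burn's real `ToElement` machinery, not the actual API):

```rust
/// Minimal stand-in for a 32-bit complex element type.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Complex32 {
    pub re: f32,
    pub im: f32,
}

/// Conversion into a complex element, living in the complex crate.
pub trait ToComplex {
    fn to_complex32(&self) -> Complex32;
}

/// Blanket impl: any real-valued type convertible to f32 maps to a
/// complex number with zero imaginary part. In the real crate this
/// bound would be `ToElement` rather than `Into<f32>`.
impl<T: Copy + Into<f32>> ToComplex for T {
    fn to_complex32(&self) -> Complex32 {
        Complex32 { re: (*self).into(), im: 0.0 }
    }
}

fn main() {
    let z = 3i16.to_complex32();
    assert_eq!(z, Complex32 { re: 3.0, im: 0.0 });
    let w = 0.5f32.to_complex32();
    assert_eq!(w.re, 0.5);
    println!("{z:?} {w:?}");
}
```

Because `ToComplex` is defined in the extension crate, the blanket impl over foreign types is still coherent, which is what lets the conversion live outside burn-tensor.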
Then, for the concrete implementations we can have feature flags similar to burn-vision:

```rust
#[cfg(feature = "ndarray")]
mod ndarray {
    use crate::ComplexTensorBackend;
    use burn_ndarray::{
        FloatNdArrayElement, IntNdArrayElement, NdArray, NdArrayTensorFloat, QuantElement,
    };

    impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> ComplexTensorBackend
        for NdArray<E, I, Q>
    {
        // ...
    }
}
```
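Feature gating like the above would pair with optional dependencies in the extension crate's manifest. A hypothetical `Cargo.toml` excerpt (the version number is a placeholder, not taken from the PR):

```toml
# Each backend implementation is gated behind an optional dependency,
# so downstream users only compile the backends they enable.
[dependencies]
burn-ndarray = { version = "0.x", optional = true }

[features]
default = []
ndarray = ["dep:burn-ndarray"]
```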
That way we can keep the extension scoped and expand backend support incrementally.
Eventually, we might move it as a core feature/trait, but I believe starting as an extension is the right approach to limit the scope.
CC'ing @nathanielsimard in case you have other thoughts on this subject.
@laggui I'm currently running into issues with `Complex{32,64}` needing to implement the element trait for `BasicOps`. I can't implement `ElementConversion` and a few other traits I've duplicated. I think moving the complex element into burn-tensor and leaving everything else in burn-complex is probably the right approach, but I'm open to other suggestions. An alternative is to create a new trait shared by `Element` and `ComplexElement` and make that the requirement for `BasicOps`, but that seems like the wrong approach to take.
It's a bit early, but I could definitely use feedback on what works and what doesn't in terms of the design.
I think the current goals are:
Checklist
- The `cargo run-checks` command has been executed.

Related Issues/PRs
Changes
The main changes in relation to the goals are:
- A `burn-complex` crate that will house a lot of the shared definitions. I'm guessing most of the stuff other than the `ComplexTensorBackend` trait and the dtype for complex numbers will be moved here.
- A `ComplexLayout` trait that will be implemented on unit structs indicating which complex data layout an implementation uses, allowing implementors to define functions and traits meant only for a specific layout.

Testing
TODO
Notes
Mostly for me to pick up where I left off:
- Check where `TensorData` is used, and work out how to handle the implementation for complex numbers, given some implementations may not be contiguous.
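The `ComplexLayout` idea from the Changes section could be sketched as follows; the marker structs and the interleaved buffer are illustrative assumptions rather than the PR's actual types, but they show how layout-specific operations can be restricted to a single layout marker:

```rust
use std::marker::PhantomData;

/// Marker trait for a complex data layout.
trait ComplexLayout {
    const NAME: &'static str;
}

/// Interleaved layout: [re0, im0, re1, im1, ...] in one buffer.
struct Interleaved;
impl ComplexLayout for Interleaved {
    const NAME: &'static str = "interleaved";
}

/// Split (planar) layout: separate real and imaginary buffers.
struct Split;
impl ComplexLayout for Split {
    const NAME: &'static str = "split";
}

/// A complex buffer tagged with its layout at the type level.
struct ComplexBuffer<L: ComplexLayout> {
    data: Vec<f32>,
    _layout: PhantomData<L>,
}

/// An op defined only for the interleaved layout: since (re, im) pairs
/// are adjacent, conjugation negates every second element in place.
impl ComplexBuffer<Interleaved> {
    fn conj_in_place(&mut self) {
        for im in self.data.iter_mut().skip(1).step_by(2) {
            *im = -*im;
        }
    }
}

fn main() {
    let mut buf = ComplexBuffer::<Interleaved> {
        data: vec![1.0, 2.0, 3.0, -4.0],
        _layout: PhantomData,
    };
    buf.conj_in_place();
    assert_eq!(buf.data, vec![1.0, -2.0, 3.0, 4.0]);
    println!("{:?} ({} vs {})", buf.data, Interleaved::NAME, Split::NAME);
}
```

Calling `conj_in_place` on a `ComplexBuffer<Split>` would be a compile error, which is the point: layout-dependent code can't accidentally run on the wrong representation, even when the underlying storage is not contiguous in the usual sense.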