
Conversation

@skewballfox (Contributor) commented Aug 24, 2025

It's a bit early, but I could definitely use feedback on what works and what doesn't in terms of the design.

I think the current goals are:

  1. Keep the implementation for complex numbers out of the core backend trait, and make it a decorator similar to autodiff.
  2. Don't constrain backends to a specific layout, but if possible find a way to define layout-dependent shared behavior (like butterfly operations).

Checklist

  • Confirmed that cargo run-checks command has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

Changes

The main changes in relation to the goals are:

  1. Made a new burn-complex crate that will house a lot of the shared definitions. I'm guessing most of the pieces other than the ComplexTensorBackend trait and the dtype for complex numbers will be moved here.
  2. Added a marker trait ComplexLayout, implemented on unit structs that indicate which complex layout an implementation uses; this allows implementors to define functions and traits meant only for a specific data layout (see the sketch below).
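
For reference, the marker-trait idea might look like this minimal sketch (the Interleaved and Planar structs and the ComplexBuffer type are hypothetical illustrations, not code from this PR):

pub trait ComplexLayout {}

pub struct Interleaved; // hypothetical layout: [re0, im0, re1, im1, ...]
impl ComplexLayout for Interleaved {}

pub struct Planar; // hypothetical layout: all real parts, then all imaginary parts
impl ComplexLayout for Planar {}

// Behavior can then be restricted to a single layout:
pub struct ComplexBuffer<L: ComplexLayout> {
    data: Vec<f32>,
    _layout: core::marker::PhantomData<L>,
}

impl ComplexBuffer<Interleaved> {
    /// Only available when the buffer uses the interleaved layout.
    pub fn pairs(&self) -> impl Iterator<Item = (f32, f32)> + '_ {
        self.data.chunks_exact(2).map(|c| (c[0], c[1]))
    }
}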

Testing

TODO

Notes

Mostly notes for me, to pick up where I left off:

  • Need to figure out where TensorData is used, and work out how to handle the implementation for complex numbers, given that some implementations may not be contiguous.

prakash-shekhar and others added 8 commits June 29, 2025 11:45
- Add Complex as first-class TensorKind alongside Float, Int, Bool
- Add ComplexTensorPrimitive and ComplexElem to Backend trait
- Add complex tensor type aliases and exports
- Add ComplexTensorPrimitive support to NdArray backend
- Implement complex arithmetic and transcendental functions in NdArray backend
- Add autodiff backend wrapper for complex tensors
- Begin enabling support across backend ecosystem
- Add high-level Tensor<B, D, Complex> API with BasicOps and Numeric traits
- Add complex-specific methods: conj(), real(), imag(), magnitude(), phase()
- Add creation utilities: from_parts(), from_polar(), zeros(), ones()
- Start adding test suite covering operations
- Remove non-existent testgen_complex!() macro call that was causing compilation errors
- Add ComplexTensorOps implementations for all backends (tch, candle, cubecl, fusion, router)
- Fix complex tensor assertion logic in CubeCL backend to avoid Float trait requirements
- Add missing transcendental functions (exp, log, sin, cos, tan, sqrt, powc) to Complex tensor API
@skewballfox (Contributor, PR author)

@laggui I'm trying to figure out what should go in the burn-tensor crate and what should go into the burn-complex crate; I could use your help in deciding where the divide should be.

  • It looks like I need to have the Kind stay in burn-tensor, since not all the required types are made public.
  • ComplexBackend needs to stay in burn-tensor, since I don't believe I can declare a supertrait on an external type; due to that dependency, so does the ComplexTensorOps trait.
  • Since I can't move the Complex unit struct used in Kind along with its Numeric impl, BasicOps also needs to stay in burn-tensor.
  • I can move the actual layout structs into burn-complex, which is probably where they will eventually need to be anyway if we want layout-specific ops (which we likely do for generic FFTs across backends), but they could also stay with the ComplexTensorBackend declaration.

Alternatively, if I make some more things in burn-tensor public (such as Kind), then I can lift almost everything out of burn-tensor.

@laggui self-requested a review August 28, 2025 02:46
@laggui (Member) left a comment


@skewballfox for now, I think we should treat the complex backend as an extension. Thus, almost all types, traits, and impls should live only in the burn-complex extension at this time.

The only part that should be added to burn-tensor is the DType variant, since we don't currently have a way to add/support custom dtypes anyway.

I believe the rest can easily live as an extension in a separate crate.

pub type ComplexTensor<B> = <B as ComplexTensorBackend>::ComplexTensorPrimitive;

pub trait ComplexTensorBackend: Backend {
    /// The inner backend type.
    type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem>;
    /// Tensor primitive to be used for all complex operations.
    type ComplexTensorPrimitive: TensorMetadata + 'static;

    /// Returns the real part of a complex tensor.
    fn real(tensor: ComplexTensor<Self>) -> FloatTensor<Self::InnerBackend>;
    /// Returns the imaginary part of a complex tensor.
    fn imag(tensor: ComplexTensor<Self>) -> FloatTensor<Self::InnerBackend>;

    /// Creates a complex tensor from a real-valued float tensor
    /// (presumably with a zero imaginary part).
    fn to_complex(tensor: FloatTensor<Self::InnerBackend>) -> ComplexTensor<Self>;
}

/// A type-level representation of the kind of a complex tensor.
#[derive(Clone, Debug)]
pub struct Complex;

impl<B: ComplexTensorBackend> TensorKind<B> for Complex {
    type Primitive = B::ComplexTensorPrimitive;
    fn name() -> &'static str {
        "Complex"
    }
}

You can easily implement the tensor ops traits for the Complex type, e.g.

impl<B: ComplexTensorBackend> BasicOps<B> for Complex {
    // ...
}

For the element type, I think the make_elem! macro will not work because ToElement does not have to_complex, but that is fine. The macro was mostly there to avoid repeating the implementation; we can either implement the Element trait manually for these types or make the macro a bit more flexible for custom external types. Maybe that will require adding a ToComplex trait (in the complex crate) and implementing it for types that implement ToElement, so we can convert between those types and complex.
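
A rough sketch of that idea (the ToComplex trait, Complex32 type, and blanket impl are hypothetical; it assumes burn_tensor::ToElement exposes a to_f32 conversion):

use burn_tensor::ToElement;

/// Hypothetical complex element type living in burn-complex.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Complex32 {
    pub re: f32,
    pub im: f32,
}

/// Hypothetical counterpart to ToElement for complex conversions.
pub trait ToComplex {
    fn to_complex32(&self) -> Complex32;
}

/// Blanket impl: any real element becomes a complex number with a zero
/// imaginary part, giving the types <> complex conversion mentioned above.
impl<T: ToElement> ToComplex for T {
    fn to_complex32(&self) -> Complex32 {
        Complex32 { re: self.to_f32(), im: 0.0 }
    }
}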

Then, for the concrete implementations, we can have feature flags similar to burn-vision:

#[cfg(feature = "ndarray")]
mod ndarray {
    use crate::ComplexTensorBackend;
    use burn_ndarray::{
        FloatNdArrayElement, IntNdArrayElement, NdArray, NdArrayTensorFloat, QuantElement,
    };

    impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> ComplexTensorBackend
        for NdArray<E, I, Q>
    {
        // ...
    }
}

This way, the extension can be enabled per backend and its scope extended incrementally.
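
For illustration, downstream code could then stay generic over any backend that opts in (a hypothetical usage sketch; the burn_complex paths assume the crate layout proposed in this PR):

use burn_complex::{ComplexTensor, ComplexTensorBackend};
use burn_tensor::ops::FloatTensor;

/// Works with NdArray, tch, etc., once they implement the extension trait.
fn real_part<B: ComplexTensorBackend>(
    tensor: ComplexTensor<B>,
) -> FloatTensor<B::InnerBackend> {
    B::real(tensor)
}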

Eventually, we might move it as a core feature/trait, but I believe starting as an extension is the right approach to limit the scope.

CC'ing @nathanielsimard in case you have other thoughts on this subject.

@skewballfox (Contributor, PR author)

@laggui I'm currently running into issues with Complex{32,64} needing to implement the Element trait for BasicOps. I can't implement ElementConversion and a few other traits I've duplicated in burn-complex without moving the complex elements themselves into burn-tensor.

I think moving the complex elements into burn-tensor and leaving everything else in burn-complex is probably the right approach, but I'm open to other suggestions. An alternative is to create a new trait that would be shared by Element and ComplexElement and make that the requirement for BasicOps, but that seems like the wrong approach to take.



Development

Successfully merging this pull request may close these issues:

  • Complex numbers
