perf: Convolution Layout Optimization with Auto-Format Selection #159

@noahgift

Summary

Implement automatic data layout selection for convolutions (NCHW vs NHWC vs custom) based on the backend and operation sequence, with lazy layout conversion at boundaries.

Inspiration: JAX jax/_src/lax/convolution.py:38-180 - supports multiple ConvDimensionNumbers for flexible data layouts.

Problem

Different hardware prefers different data layouts:

  • Intel AVX: NCHW (channels first) for better vectorization
  • ARM NEON: NHWC (channels last) for sequential access
  • cuDNN: Depends on operation (some prefer NCHW, others NHWC)
  • Fused ops: May prefer specific layouts

Current implementation forces a single layout, causing:

  • Unnecessary transposes at operation boundaries
  • Suboptimal memory access patterns
  • Missed fusion opportunities
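
To make the access-pattern point concrete, here is a minimal sketch of how the same logical element maps to different linear offsets in contiguous NCHW and NHWC buffers. The function names `offset_nchw` and `offset_nhwc` are hypothetical and not part of the proposal; `dims` is always given in logical `[N, C, H, W]` order.

```rust
/// Linear offset of logical element (n, c, h, w) in a contiguous NCHW buffer.
/// `dims` is the logical shape [N, C, H, W].
fn offset_nchw(dims: [usize; 4], idx: [usize; 4]) -> usize {
    let [_n_dim, c_dim, h_dim, w_dim] = dims;
    let [n, c, h, w] = idx;
    // Width is the fastest-varying axis, then height, then channel.
    ((n * c_dim + c) * h_dim + h) * w_dim + w
}

/// Linear offset of the same logical element in a contiguous NHWC buffer.
/// `dims` is still the logical shape [N, C, H, W].
fn offset_nhwc(dims: [usize; 4], idx: [usize; 4]) -> usize {
    let [_n_dim, c_dim, h_dim, w_dim] = dims;
    let [n, c, h, w] = idx;
    // Channel is the fastest-varying axis, then width, then height.
    ((n * h_dim + h) * w_dim + w) * c_dim + c
}
```

For a logical shape of `[1, 3, 4, 5]`, stepping one channel moves 20 elements (H * W) in NCHW but only 1 element in NHWC, while stepping one pixel along the width moves 1 element in NCHW and 3 elements (C) in NHWC. Which of these is preferable depends on how a given kernel vectorizes.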

Proposed Solution

1. Layout Specification

/// Data layout for convolution tensors
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConvLayout {
    /// Batch dimension index
    pub batch: usize,
    /// Channel dimension index
    pub channel: usize,
    /// Spatial dimension indices (height, width, ...)
    pub spatial: [usize; 3],  // Max 3D convolution
    /// Number of spatial dimensions
    pub num_spatial: usize,
}

impl ConvLayout {
    /// NCHW layout (batch, channel, height, width)
    pub const NCHW: Self = Self {
        batch: 0,
        channel: 1,
        spatial: [2, 3, 0],
        num_spatial: 2,
    };
    
    /// NHWC layout (batch, height, width, channel)
    pub const NHWC: Self = Self {
        batch: 0,
        channel: 3,
        spatial: [1, 2, 0],
        num_spatial: 2,
    };
    
    /// NCDHW for 3D convolutions
    pub const NCDHW: Self = Self {
        batch: 0,
        channel: 1,
        spatial: [2, 3, 4],
        num_spatial: 3,
    };
    
    /// Check if this is a "channels first" layout
    pub fn is_channels_first(&self) -> bool {
        self.channel < self.spatial[0]
    }
    
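    /// Get the dimension permutation that converts a tensor from `other`
    /// into this layout: `perm[i]` is the axis of `other` that becomes
    /// axis `i` of `self`.
    pub fn permutation_from(&self, other: &ConvLayout) -> Vec<usize> {
        debug_assert_eq!(self.num_spatial, other.num_spatial);
        // Rank = batch + channel + spatial dimensions
        let mut perm = vec![0; 2 + self.num_spatial];
        perm[self.batch] = other.batch;
        perm[self.channel] = other.channel;
        for i in 0..self.num_spatial {
            perm[self.spatial[i]] = other.spatial[i];
        }
        perm
    }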
}

/// Full dimension specification including kernel layout
#[derive(Debug, Clone)]
pub struct ConvDimensionNumbers {
    /// Input tensor layout
    pub input_layout: ConvLayout,
    /// Kernel/filter layout (out_channels, in_channels, spatial...)
    pub kernel_layout: KernelLayout,
    /// Output tensor layout
    pub output_layout: ConvLayout,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct KernelLayout {
    /// Output channels dimension
    pub out_channel: usize,
    /// Input channels dimension
    pub in_channel: usize,
    /// Spatial dimensions
    pub spatial: [usize; 3],
    pub num_spatial: usize,
}

impl KernelLayout {
    /// OIHW (out, in, height, width) - standard
    pub const OIHW: Self = Self {
        out_channel: 0,
        in_channel: 1,
        spatial: [2, 3, 0],
        num_spatial: 2,
    };
    
    /// HWIO (height, width, in, out) - TensorFlow style
    pub const HWIO: Self = Self {
        out_channel: 3,
        in_channel: 2,
        spatial: [0, 1, 0],
        num_spatial: 2,
    };
}
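
As a worked example of the permutation logic, the sketch below computes the NCHW to NHWC permutation and applies it to a shape. It uses a stripped-down stand-in type (`Layout`) and hypothetical helper names (`permutation`, `permute_shape`) so that it compiles on its own; it is an illustration, not the proposed API.

```rust
/// Stand-in for the proposal's `ConvLayout`, reduced to what the example needs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Layout {
    batch: usize,
    channel: usize,
    spatial: [usize; 3],
    num_spatial: usize,
}

const NCHW: Layout = Layout { batch: 0, channel: 1, spatial: [2, 3, 0], num_spatial: 2 };
const NHWC: Layout = Layout { batch: 0, channel: 3, spatial: [1, 2, 0], num_spatial: 2 };

/// `perm[i]` is the axis of `from` that lands at axis `i` of `to`.
fn permutation(to: &Layout, from: &Layout) -> Vec<usize> {
    assert_eq!(to.num_spatial, from.num_spatial);
    let mut perm = vec![0; 2 + to.num_spatial];
    perm[to.batch] = from.batch;
    perm[to.channel] = from.channel;
    for i in 0..to.num_spatial {
        perm[to.spatial[i]] = from.spatial[i];
    }
    perm
}

/// Apply a permutation to a shape: output axis `i` takes the extent of input axis `perm[i]`.
fn permute_shape(shape: &[usize], perm: &[usize]) -> Vec<usize> {
    perm.iter().map(|&src| shape[src]).collect()
}
```

Converting NCHW to NHWC gives the permutation `[0, 2, 3, 1]`, which turns an `[8, 3, 32, 64]` shape into `[8, 32, 64, 3]`. The reverse direction gives `[0, 3, 1, 2]`.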

2. Layout-Aware Tensor

/// Tensor with explicit layout tracking
/// (`Clone` is needed by `to_layout`; assumes `Tensor<T>: Clone`)
#[derive(Clone)]
pub struct LayoutTensor<T, const N: usize> {
    data: Tensor<T>,
    layout: TensorLayout<N>,
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TensorLayout<const N: usize> {
    /// Logical dimension order
    pub dim_order: [usize; N],
    /// Stride for each logical dimension
    pub strides: [usize; N],
}

impl<T: TensorElement, const N: usize> LayoutTensor<T, N> {
    /// Convert to different layout (lazy if possible)
    pub fn to_layout(&self, target: TensorLayout<N>) -> Self {
        if self.layout == target {
            return self.clone();
        }
        
        // Check if we can use a view (just change strides)
        if self.can_view_as(&target) {
            return Self {
                data: self.data.view_with_strides(target.strides),
                layout: target,
            };
        }
        
        // Need actual transpose
        let perm = target.permutation_from(&self.layout);
        Self {
            data: self.data.permute(&perm),
            layout: target,
        }
    }
    
    /// Check if layout conversion can be done as a view
    fn can_view_as(&self, target: &TensorLayout<N>) -> bool {
        // Can view if target strides are a permutation of current strides
        // and data is contiguous in memory
        self.data.is_contiguous() && 
            self.layout.is_permutation_of(target)
    }
}
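
The "view instead of copy" path can be illustrated with plain stride arithmetic. The sketch below assumes a row-major contiguous buffer and shows that permuting shape and strides together yields a different logical layout over the same memory. The helper names (`contiguous_strides`, `permuted_view`, `offset`) are hypothetical and do not correspond to any existing `Tensor` method.

```rust
/// Row-major (contiguous) strides for `shape`, measured in elements.
fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1; shape.len()];
    for i in (0..shape.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}

/// Reorder shape and strides by `perm` without touching the underlying data.
/// Output axis `i` takes the extent and stride of input axis `perm[i]`.
fn permuted_view(shape: &[usize], strides: &[usize], perm: &[usize]) -> (Vec<usize>, Vec<usize>) {
    let new_shape = perm.iter().map(|&p| shape[p]).collect();
    let new_strides = perm.iter().map(|&p| strides[p]).collect();
    (new_shape, new_strides)
}

/// Linear offset of a multi-dimensional index under the given strides.
fn offset(strides: &[usize], idx: &[usize]) -> usize {
    strides.iter().zip(idx).map(|(s, i)| s * i).sum()
}
```

For an NCHW tensor of shape `[2, 3, 4, 5]` the contiguous strides are `[60, 20, 5, 1]`. Viewing it as NHWC with permutation `[0, 2, 3, 1]` gives shape `[2, 4, 5, 3]` and strides `[60, 5, 1, 20]`, and the element at NCHW index `[1, 2, 3, 4]` is found at NHWC index `[1, 3, 4, 2]` with the same offset. The view is free, but a kernel that requires physically contiguous NHWC data would still need a copy.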

3. Automatic Layout Selection

/// Layout preferences for different backends
pub struct LayoutPolicy {
    /// Preferred input layout
    pub preferred_input: ConvLayout,
    /// Preferred kernel layout
    pub preferred_kernel: KernelLayout,
    /// Preferred output layout
    pub preferred_output: ConvLayout,
    /// Cost of layout conversion per tensor element (in equivalent conv FLOPs)
    pub conversion_cost_factor: f64,
}

impl LayoutPolicy {
    /// Policy for x86_64 with AVX2
    #[cfg(target_arch = "x86_64")]
    pub fn default_cpu() -> Self {
        Self {
            preferred_input: ConvLayout::NCHW,
            preferred_kernel: KernelLayout::OIHW,
            preferred_output: ConvLayout::NCHW,
            conversion_cost_factor: 0.1,  // Transpose is 10% of conv cost
        }
    }
    
    /// Policy for ARM with NEON
    #[cfg(target_arch = "aarch64")]
    pub fn default_cpu() -> Self {
        Self {
            preferred_input: ConvLayout::NHWC,
            preferred_kernel: KernelLayout::HWIO,
            preferred_output: ConvLayout::NHWC,
            conversion_cost_factor: 0.1,
        }
    }
    
    /// Policy for WebGPU
    pub fn webgpu() -> Self {
        Self {
            preferred_input: ConvLayout::NHWC,
            preferred_kernel: KernelLayout::HWIO,
            preferred_output: ConvLayout::NHWC,
            conversion_cost_factor: 0.05,  // GPU transposes are cheaper
        }
    }
}

/// Select optimal layout for a sequence of operations
pub fn select_layout_for_sequence(
    ops: &[ConvOp],
    policy: &LayoutPolicy,
) -> Vec<ConvDimensionNumbers> {
    // Dynamic programming to minimize total cost
    let n = ops.len();
    
    // State: (op_index, current_layout) -> min_cost
    let layouts = [ConvLayout::NCHW, ConvLayout::NHWC];
    let mut dp = vec![vec![f64::INFINITY; layouts.len()]; n + 1];
    let mut parent = vec![vec![0usize; layouts.len()]; n + 1];
    
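    // Base case: the input starts in the policy's preferred layout (cost 0);
    // any other starting layout is charged one conversion. The initial tensor
    // size is not known here, so the unscaled factor is used.
    for (j, &layout) in layouts.iter().enumerate() {
        dp[0][j] = if layout == policy.preferred_input {
            0.0
        } else {
            policy.conversion_cost_factor
        };
    }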
    
    for i in 0..n {
        for (j, &layout) in layouts.iter().enumerate() {
            if dp[i][j] == f64::INFINITY {
                continue;
            }
            
            // Cost of this op with this layout
            let op_cost = estimate_conv_cost(&ops[i], layout, policy);
            
            // Try each output layout
            for (k, &out_layout) in layouts.iter().enumerate() {
                let convert_cost = if layout == out_layout {
                    0.0
                } else {
                    policy.conversion_cost_factor * ops[i].output_size() as f64
                };
                
                let total = dp[i][j] + op_cost + convert_cost;
                if total < dp[i + 1][k] {
                    dp[i + 1][k] = total;
                    parent[i + 1][k] = j;
                }
            }
        }
    }
    
    // Backtrack to find optimal sequence
    let mut result = Vec::with_capacity(n);
    let mut current = if dp[n][0] < dp[n][1] { 0 } else { 1 };
    
    for i in (0..n).rev() {
        let in_layout = layouts[parent[i + 1][current]];
        let out_layout = layouts[current];
        
        result.push(ConvDimensionNumbers {
            input_layout: in_layout,
            kernel_layout: policy.preferred_kernel,
            output_layout: out_layout,
        });
        
        current = parent[i + 1][current];
    }
    
    result.reverse();
    result
}
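
The dynamic program above can be exercised in isolation with explicit cost tables. The sketch below is a simplified variant, not the proposed function: it assumes exactly two candidate layouts, a fixed conversion cost instead of one scaled by output size, and a hypothetical `plan_layouts` signature in which `op_cost[i][j]` is the cost of running op `i` in layout `j`.

```rust
/// Choose a layout index (0 or 1) for each op so that the sum of per-op costs
/// and layout-switch costs is minimal. `start` is the layout the input arrives in.
/// Returns (total cost, chosen layout per op).
fn plan_layouts(op_cost: &[[f64; 2]], start: usize, convert: f64) -> (f64, Vec<usize>) {
    let n = op_cost.len();
    if n == 0 {
        return (0.0, Vec::new());
    }
    // best[i][j]: minimal cost of running ops 0..=i with op i in layout j.
    let mut best = vec![[f64::INFINITY; 2]; n];
    // prev[i][j]: layout of op i-1 on the best path ending in (i, j).
    let mut prev = vec![[0usize; 2]; n];

    for j in 0..2 {
        let entry = if j == start { 0.0 } else { convert };
        best[0][j] = entry + op_cost[0][j];
    }
    for i in 1..n {
        for j in 0..2 {
            for k in 0..2 {
                let switch = if j == k { 0.0 } else { convert };
                let total = best[i - 1][k] + switch + op_cost[i][j];
                if total < best[i][j] {
                    best[i][j] = total;
                    prev[i][j] = k;
                }
            }
        }
    }

    // Backtrack from the cheaper final state.
    let mut cur = if best[n - 1][0] <= best[n - 1][1] { 0 } else { 1 };
    let total = best[n - 1][cur];
    let mut plan = vec![0; n];
    for i in (0..n).rev() {
        plan[i] = cur;
        cur = prev[i][cur];
    }
    (total, plan)
}
```

With five ops where the first two are cheaper in layout 0 (10 vs 15) and the last three are cheaper in layout 1 (8 vs 20), a conversion cost of 5 yields the plan `[0, 0, 1, 1, 1]` at total cost 49, compared with 80 for staying in layout 0 throughout. Raising the conversion cost to 100 makes the planner stay in layout 0.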

4. Lazy Layout Conversion

/// Tensor that defers layout conversion until needed
pub struct LazyLayoutTensor<T> {
    data: Tensor<T>,
    current_layout: ConvLayout,
    /// Pending layout conversion (if any)
    pending_conversion: Option<ConvLayout>,
}

impl<T: TensorElement> LazyLayoutTensor<T> {
    /// Request a layout change (deferred). Requesting the current layout
    /// clears any previously pending conversion.
    pub fn request_layout(&mut self, target: ConvLayout) {
        self.pending_conversion = if self.current_layout != target {
            Some(target)
        } else {
            None
        };
    }
    
    /// Materialize any pending conversion
    pub fn materialize(&mut self) {
        if let Some(target) = self.pending_conversion.take() {
            let perm = target.permutation_from(&self.current_layout);
            self.data = self.data.permute(&perm);
            self.current_layout = target;
        }
    }
    
    /// Get data, materializing if needed
    pub fn get(&mut self) -> &Tensor<T> {
        self.materialize();
        &self.data
    }
    
    /// Try to cancel pending conversion if next op prefers current layout
    pub fn try_cancel_conversion(&mut self, preferred: ConvLayout) -> bool {
        if self.current_layout == preferred {
            self.pending_conversion = None;
            true
        } else {
            false
        }
    }
}
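
The benefit of deferring conversion is that a request which is later reversed never touches the data. The sketch below models this with a counter in place of a real tensor; the `Lazy` type and its `conversions` field are hypothetical and exist only to make the behaviour observable.

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Layout {
    Nchw,
    Nhwc,
}

/// Minimal model of a lazily converted tensor: tracks layouts only.
struct Lazy {
    current: Layout,
    pending: Option<Layout>,
    /// Number of physical conversions performed so far.
    conversions: usize,
}

impl Lazy {
    fn new(current: Layout) -> Self {
        Self { current, pending: None, conversions: 0 }
    }

    /// Record the desired layout; requesting the current layout cancels any pending one.
    fn request(&mut self, target: Layout) {
        self.pending = if target == self.current { None } else { Some(target) };
    }

    /// Perform the pending conversion, if any, and return the resulting layout.
    fn materialize(&mut self) -> Layout {
        if let Some(target) = self.pending.take() {
            // A real tensor would permute its data here.
            self.current = target;
            self.conversions += 1;
        }
        self.current
    }
}
```

Requesting NHWC and then NCHW on an NCHW tensor before materializing results in zero conversions, whereas an eager implementation would have transposed twice.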

5. Layout-Optimized Convolution

/// Convolution with automatic layout handling
pub fn conv2d_auto_layout<T: TensorElement>(
    input: &LayoutTensor<T, 4>,
    kernel: &LayoutTensor<T, 4>,
    stride: (usize, usize),
    padding: Padding,
) -> LayoutTensor<T, 4> {
    let policy = LayoutPolicy::default_cpu();
    
    // Check if conversion is needed
    let need_input_convert = input.layout() != policy.preferred_input;
    let need_kernel_convert = kernel.layout() != policy.preferred_kernel;
    
    // Estimate costs
    let conv_flops = estimate_conv_flops(input.shape(), kernel.shape(), stride);
    let convert_input_cost = if need_input_convert {
        input.len() as f64 * policy.conversion_cost_factor
    } else {
        0.0
    };
    let convert_kernel_cost = if need_kernel_convert {
        kernel.len() as f64 * policy.conversion_cost_factor
    } else {
        0.0
    };
    
    // Decide whether to convert or use slower non-preferred layout
    let use_preferred = conv_flops as f64 > convert_input_cost + convert_kernel_cost;
    
    if use_preferred {
        // Convert and use optimized kernel
        let input_converted = input.to_layout(policy.preferred_input);
        let kernel_converted = kernel.to_layout(policy.preferred_kernel);
        
        let output = conv2d_optimized(&input_converted, &kernel_converted, stride, padding);
        LayoutTensor::new(output, policy.preferred_output)
    } else {
        // Use current layout with generic kernel
        let output = conv2d_generic(input, kernel, stride, padding);
        LayoutTensor::new(output, input.layout())
    }
}
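
The convert-or-not decision compares an estimate of convolution work against a per-element conversion cost. The sketch below spells out one common FLOP estimate (two operations per multiply-accumulate) and the comparison itself. Both `conv2d_flops` and `should_convert` are hypothetical helpers; the proposal's `estimate_conv_flops` is not specified and may differ.

```rust
/// Approximate FLOPs of a dense 2D convolution, counting each
/// multiply-accumulate as two operations.
fn conv2d_flops(
    batch: u64,
    c_in: u64,
    c_out: u64,
    out_hw: (u64, u64),
    kernel_hw: (u64, u64),
) -> u64 {
    2 * batch * c_out * out_hw.0 * out_hw.1 * c_in * kernel_hw.0 * kernel_hw.1
}

/// Convert to the preferred layout when the estimated convolution work
/// exceeds the estimated conversion cost (elements * per-element factor).
fn should_convert(conv_flops: u64, elements_to_convert: u64, cost_factor: f64) -> bool {
    conv_flops as f64 > elements_to_convert as f64 * cost_factor
}
```

For a 3x3 convolution from 3 to 16 channels over a 32x32 output with batch size 1, the estimate is 884,736 FLOPs, while converting the 3,072 input elements and 432 kernel elements at a factor of 0.1 costs about 350. Under these assumptions the comparison favours conversion for most realistic shapes, so the factor may need calibrating against measured transpose and kernel timings before the fallback branch is ever taken.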

6. Im2col with Layout Support

/// Im2col that respects input layout
pub fn im2col_with_layout<T: TensorElement>(
    input: &LayoutTensor<T, 4>,
    kernel_size: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
) -> Tensor<T> {
    let layout = input.layout();
    let shape = input.shape();
    
    let (batch, channels, height, width) = if layout.is_channels_first() {
        (shape[0], shape[1], shape[2], shape[3])
    } else {
        (shape[0], shape[3], shape[1], shape[2])
    };
    
    let out_h = (height + 2 * padding.0 - kernel_size.0) / stride.0 + 1;
    let out_w = (width + 2 * padding.1 - kernel_size.1) / stride.1 + 1;
    
    let col_size = channels * kernel_size.0 * kernel_size.1;
    let mut col = Tensor::zeros(&[batch, col_size, out_h * out_w]);
    
    // Layout-aware extraction
    if layout.is_channels_first() {
        im2col_nchw(input.data(), &mut col, kernel_size, stride, padding);
    } else {
        im2col_nhwc(input.data(), &mut col, kernel_size, stride, padding);
    }
    
    col
}
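
For reference, a minimal im2col for the channels-first case is sketched below. It handles a single image stored as a contiguous CHW slice of `f32` with zero padding, and is meant only to show the indexing that `im2col_nchw` would need. The names `out_dim` and `im2col_chw` are hypothetical.

```rust
/// Output extent along one spatial axis for symmetric padding `pad`.
fn out_dim(input: usize, kernel: usize, stride: usize, pad: usize) -> usize {
    (input + 2 * pad - kernel) / stride + 1
}

/// im2col for one image stored as contiguous CHW, with zero padding.
/// Returns a row-major matrix of shape [c * kh * kw, out_h * out_w].
fn im2col_chw(
    input: &[f32],
    c: usize,
    h: usize,
    w: usize,
    k: (usize, usize),
    stride: (usize, usize),
    pad: (usize, usize),
) -> Vec<f32> {
    let out_h = out_dim(h, k.0, stride.0, pad.0);
    let out_w = out_dim(w, k.1, stride.1, pad.1);
    let cols = out_h * out_w;
    let mut col = vec![0.0; c * k.0 * k.1 * cols];

    for ch in 0..c {
        for ky in 0..k.0 {
            for kx in 0..k.1 {
                // One output row per (channel, kernel row, kernel column).
                let row = (ch * k.0 + ky) * k.1 + kx;
                for oy in 0..out_h {
                    for ox in 0..out_w {
                        // Coordinates in the padded image.
                        let iy = oy * stride.0 + ky;
                        let ix = ox * stride.1 + kx;
                        // Positions inside the zero border keep the initial 0.0.
                        if iy < pad.0 || ix < pad.1 || iy - pad.0 >= h || ix - pad.1 >= w {
                            continue;
                        }
                        let src = (ch * h + (iy - pad.0)) * w + (ix - pad.1);
                        col[row * cols + oy * out_w + ox] = input[src];
                    }
                }
            }
        }
    }
    col
}
```

An NHWC variant would keep the same loop structure and change only the source index to `((iy - pad.0) * w + (ix - pad.1)) * c + ch`, which makes the channel loop the natural innermost loop for that layout.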

Acceptance Criteria

  • ConvLayout with NCHW, NHWC, NCDHW variants
  • KernelLayout with OIHW, HWIO variants
  • ConvDimensionNumbers for full specification
  • LayoutTensor with explicit layout tracking
  • LayoutPolicy for different backends
  • Dynamic programming layout sequence optimizer
  • LazyLayoutTensor with deferred conversion
  • Im2col with layout support
  • Benchmarks for different layout scenarios

Expected Performance Impact

| Scenario | Fixed Layout | Auto Layout | Improvement |
|---|---|---|---|
| Single conv (NCHW native) | 10 ms | 10 ms | 0% |
| Single conv (NHWC on NCHW) | 15 ms | 12 ms | 20% |
| Conv sequence (5 ops) | 75 ms | 55 ms | 27% |
| Mixed precision (convert chain) | 100 ms | 60 ms | 40% |

Expected improvement: 20-40% for non-native layouts and multi-op sequences; no change for a single convolution already in the native layout.

Labels

performance, convolution, layout, P1-high
