Summary
Implement automatic data layout selection for convolutions (NCHW vs NHWC vs custom) based on the backend and operation sequence, with lazy layout conversion at boundaries.
Inspiration: JAX (jax/_src/lax/convolution.py:38-180), which supports multiple ConvDimensionNumbers for flexible data layouts.
Problem
Different hardware prefers different data layouts:
- Intel AVX: NCHW (channels first) for better vectorization
- ARM NEON: NHWC (channels last) for sequential access
- cuDNN: Depends on operation (some prefer NCHW, others NHWC)
- Fused ops: May prefer specific layouts
Current implementation forces a single layout, causing:
- Unnecessary transposes at operation boundaries
- Suboptimal memory access patterns
- Missed fusion opportunities
Proposed Solution
1. Layout Specification
/// Data layout for convolution tensors
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConvLayout {
    /// Batch dimension index
    pub batch: usize,
    /// Channel dimension index
    pub channel: usize,
    /// Spatial dimension indices (height, width, ...)
    pub spatial: [usize; 3], // Max 3D convolution
    /// Number of spatial dimensions
    pub num_spatial: usize,
}

impl ConvLayout {
    /// NCHW layout (batch, channel, height, width)
    pub const NCHW: Self = Self {
        batch: 0,
        channel: 1,
        spatial: [2, 3, 0], // last slot unused (num_spatial = 2)
        num_spatial: 2,
    };

    /// NHWC layout (batch, height, width, channel)
    pub const NHWC: Self = Self {
        batch: 0,
        channel: 3,
        spatial: [1, 2, 0], // last slot unused (num_spatial = 2)
        num_spatial: 2,
    };

    /// NCDHW for 3D convolutions
    pub const NCDHW: Self = Self {
        batch: 0,
        channel: 1,
        spatial: [2, 3, 4],
        num_spatial: 3,
    };

    /// Check if this is a "channels first" layout
    pub fn is_channels_first(&self) -> bool {
        self.channel < self.spatial[0]
    }

    /// Get dimension permutation from another layout
    pub fn permutation_from(&self, other: &ConvLayout) -> Vec<usize> {
        // Tensor rank is batch + channel + spatial dimensions.
        let mut perm = vec![0; 2 + self.num_spatial];
        perm[self.batch] = other.batch;
        perm[self.channel] = other.channel;
        for i in 0..self.num_spatial {
            perm[self.spatial[i]] = other.spatial[i];
        }
        perm
    }
}
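A quick sanity check of permutation_from, written as a hypothetical unit test. The [0, 2, 3, 1] expectation follows from the convention that new axis i reads from old axis perm[i]:

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn nchw_to_nhwc_permutation() {
        // NCHW -> NHWC: batch stays at axis 0, H comes from axis 2,
        // W from axis 3, and channels move from axis 1 to the end.
        let perm = ConvLayout::NHWC.permutation_from(&ConvLayout::NCHW);
        assert_eq!(perm, vec![0, 2, 3, 1]);
    }
}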
/// Full dimension specification including kernel layout
#[derive(Debug, Clone)]
pub struct ConvDimensionNumbers {
    /// Input tensor layout
    pub input_layout: ConvLayout,
    /// Kernel/filter layout (out_channels, in_channels, spatial...)
    pub kernel_layout: KernelLayout,
    /// Output tensor layout
    pub output_layout: ConvLayout,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)] // PartialEq: compared against policy preferences below
pub struct KernelLayout {
    /// Output channels dimension
    pub out_channel: usize,
    /// Input channels dimension
    pub in_channel: usize,
    /// Spatial dimensions
    pub spatial: [usize; 3],
    pub num_spatial: usize,
}

impl KernelLayout {
    /// OIHW (out, in, height, width) - standard
    pub const OIHW: Self = Self {
        out_channel: 0,
        in_channel: 1,
        spatial: [2, 3, 0],
        num_spatial: 2,
    };

    /// HWIO (height, width, in, out) - TensorFlow style
    pub const HWIO: Self = Self {
        out_channel: 3,
        in_channel: 2,
        spatial: [0, 1, 0],
        num_spatial: 2,
    };
}
2. Layout-Aware Tensor
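The LayoutTensor below wraps a backing Tensor whose API this issue never pins down. The sketch relies on roughly this surface; the trait name and signatures are assumptions for illustration only:

// Assumed backing-tensor surface (not part of this proposal's API).
pub trait TensorOps<T>: Sized {
    /// Reinterpret the same buffer with different strides (no copy).
    fn view_with_strides<const N: usize>(&self, strides: [usize; N]) -> Self;
    /// Materialize a copy with axes reordered so new axis i reads old axis perm[i].
    fn permute(&self, perm: &[usize]) -> Self;
    /// Dense in memory, so stride-only views are valid.
    fn is_contiguous(&self) -> bool;
}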
/// Tensor with explicit layout tracking
#[derive(Clone)] // to_layout() hands back a clone when no conversion is needed
pub struct LayoutTensor<T, const N: usize> {
    data: Tensor<T>,
    layout: TensorLayout<N>,
}

#[derive(Debug, Clone, PartialEq, Eq)] // PartialEq: to_layout() compares layouts
pub struct TensorLayout<const N: usize> {
    /// Logical dimension order
    pub dim_order: [usize; N],
    /// Stride for each logical dimension
    pub strides: [usize; N],
}
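to_layout() below also leans on two TensorLayout helpers that the issue doesn't define. A minimal sketch of plausible implementations:

impl<const N: usize> TensorLayout<N> {
    /// True if `other` pairs the same logical dimensions with the same
    /// strides, merely listed in a different order, so a stride-swapping
    /// view can stand in for a real transpose.
    pub fn is_permutation_of(&self, other: &TensorLayout<N>) -> bool {
        (0..N).all(|i| {
            other.dim_order.iter().zip(other.strides.iter())
                .any(|(&d, &s)| d == self.dim_order[i] && s == self.strides[i])
        })
    }

    /// Permutation `perm` such that new axis `i` reads from old axis `perm[i]`.
    pub fn permutation_from(&self, other: &TensorLayout<N>) -> [usize; N] {
        let mut perm = [0usize; N];
        for i in 0..N {
            perm[i] = other.dim_order.iter()
                .position(|&d| d == self.dim_order[i])
                .expect("layouts must cover the same logical dimensions");
        }
        perm
    }
}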
impl<T: TensorElement, const N: usize> LayoutTensor<T, N> {
    /// Convert to different layout (lazy if possible)
    pub fn to_layout(&self, target: TensorLayout<N>) -> Self {
        if self.layout == target {
            return self.clone();
        }
        // Check if we can use a view (just change strides)
        if self.can_view_as(&target) {
            return Self {
                data: self.data.view_with_strides(target.strides),
                layout: target,
            };
        }
        // Need actual transpose
        let perm = target.permutation_from(&self.layout);
        Self {
            data: self.data.permute(&perm),
            layout: target,
        }
    }

    /// Check if layout conversion can be done as a view
    fn can_view_as(&self, target: &TensorLayout<N>) -> bool {
        // Can view if target strides are a permutation of current strides
        // and data is contiguous in memory
        self.data.is_contiguous() && self.layout.is_permutation_of(target)
    }
}
3. Automatic Layout Selection
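select_layout_for_sequence() further down operates on ConvOp values, which this issue leaves unspecified. A hypothetical minimal definition, just enough for the cost model; all field names are placeholders:

// Hypothetical op description; only what the optimizer's cost model touches.
pub struct ConvOp {
    /// NCHW-ordered input shape: [batch, in_channels, height, width]
    pub input_shape: [usize; 4],
    /// OIHW-ordered kernel shape: [out_channels, in_channels, kh, kw]
    pub kernel_shape: [usize; 4],
    pub stride: (usize, usize),
}

impl ConvOp {
    /// Output element count; used to scale layout-conversion cost.
    pub fn output_size(&self) -> usize {
        let out_h = self.input_shape[2] / self.stride.0; // padding ignored
        let out_w = self.input_shape[3] / self.stride.1;
        self.input_shape[0] * self.kernel_shape[0] * out_h * out_w
    }
}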
/// Layout preferences for different backends
pub struct LayoutPolicy {
    /// Preferred input layout
    pub preferred_input: ConvLayout,
    /// Preferred kernel layout
    pub preferred_kernel: KernelLayout,
    /// Preferred output layout
    pub preferred_output: ConvLayout,
    /// Cost of layout conversion (in equivalent conv FLOPs)
    pub conversion_cost_factor: f64,
}

impl LayoutPolicy {
    /// Policy for x86_64 with AVX2
    #[cfg(target_arch = "x86_64")]
    pub fn default_cpu() -> Self {
        Self {
            preferred_input: ConvLayout::NCHW,
            preferred_kernel: KernelLayout::OIHW,
            preferred_output: ConvLayout::NCHW,
            conversion_cost_factor: 0.1, // Transpose is 10% of conv cost
        }
    }

    /// Policy for ARM with NEON
    #[cfg(target_arch = "aarch64")]
    pub fn default_cpu() -> Self {
        Self {
            preferred_input: ConvLayout::NHWC,
            preferred_kernel: KernelLayout::HWIO,
            preferred_output: ConvLayout::NHWC,
            conversion_cost_factor: 0.1,
        }
    }

    /// Policy for WebGPU
    pub fn webgpu() -> Self {
        Self {
            preferred_input: ConvLayout::NHWC,
            preferred_kernel: KernelLayout::HWIO,
            preferred_output: ConvLayout::NHWC,
            conversion_cost_factor: 0.05, // GPU transposes are cheaper
        }
    }
}
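The optimizer below also calls estimate_conv_cost(), which the issue doesn't specify. A placeholder sketch against the hypothetical ConvOp above; the 1.3x penalty is an arbitrary stand-in for a measured per-backend slowdown:

// Placeholder cost model: FLOPs, inflated when the layout isn't preferred.
fn estimate_conv_cost(op: &ConvOp, layout: ConvLayout, policy: &LayoutPolicy) -> f64 {
    let flops = op.output_size() as f64
        * (op.kernel_shape[1] * op.kernel_shape[2] * op.kernel_shape[3]) as f64;
    if layout == policy.preferred_input {
        flops
    } else {
        flops * 1.3 // assumed slowdown for a non-preferred layout
    }
}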
/// Select optimal layout for a sequence of operations
pub fn select_layout_for_sequence(
    ops: &[ConvOp],
    policy: &LayoutPolicy,
) -> Vec<ConvDimensionNumbers> {
    // Dynamic programming to minimize total cost
    let n = ops.len();
    // State: (op_index, current_layout) -> min_cost
    let layouts = [ConvLayout::NCHW, ConvLayout::NHWC];
    let mut dp = vec![vec![f64::INFINITY; layouts.len()]; n + 1];
    let mut parent = vec![vec![0usize; layouts.len()]; n + 1];
    // Base case: assume the incoming tensor is NCHW (cost 0); starting the
    // first op in NHWC pays one conversion up front.
    dp[0][0] = 0.0; // NCHW
    dp[0][1] = policy.conversion_cost_factor; // NHWC requires a convert
    for i in 0..n {
        for (j, &layout) in layouts.iter().enumerate() {
            if dp[i][j] == f64::INFINITY {
                continue;
            }
            // Cost of this op with this layout
            let op_cost = estimate_conv_cost(&ops[i], layout, policy);
            // Try each output layout
            for (k, &out_layout) in layouts.iter().enumerate() {
                let convert_cost = if layout == out_layout {
                    0.0
                } else {
                    policy.conversion_cost_factor * ops[i].output_size() as f64
                };
                let total = dp[i][j] + op_cost + convert_cost;
                if total < dp[i + 1][k] {
                    dp[i + 1][k] = total;
                    parent[i + 1][k] = j;
                }
            }
        }
    }
    // Backtrack to find optimal sequence
    let mut result = Vec::with_capacity(n);
    let mut current = if dp[n][0] < dp[n][1] { 0 } else { 1 };
    for i in (0..n).rev() {
        let in_layout = layouts[parent[i + 1][current]];
        let out_layout = layouts[current];
        result.push(ConvDimensionNumbers {
            input_layout: in_layout,
            kernel_layout: policy.preferred_kernel,
            output_layout: out_layout,
        });
        current = parent[i + 1][current];
    }
    result.reverse();
    result
}
4. Lazy Layout Conversion
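The intended call pattern, sketched against the type defined just below (assumes the tensor starts in NCHW; construction is elided). Two opposing requests cancel out, so no permute ever runs:

// Usage sketch for LazyLayoutTensor.
fn example_flow<T: TensorElement>(mut t: LazyLayoutTensor<T>) {
    t.request_layout(ConvLayout::NHWC); // deferred: no data movement yet
    t.request_layout(ConvLayout::NCHW); // back to current layout: pending conversion dropped
    let _data = t.get();                // nothing to materialize; zero permutes
}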
/// Tensor that defers layout conversion until needed
pub struct LazyLayoutTensor<T> {
    data: Tensor<T>,
    current_layout: ConvLayout,
    /// Pending layout conversion (if any)
    pending_conversion: Option<ConvLayout>,
}

impl<T: TensorElement> LazyLayoutTensor<T> {
    /// Request a layout change (deferred)
    pub fn request_layout(&mut self, target: ConvLayout) {
        if self.current_layout == target {
            // Already in the target layout: drop any stale pending conversion.
            self.pending_conversion = None;
        } else {
            self.pending_conversion = Some(target);
        }
    }

    /// Materialize any pending conversion
    pub fn materialize(&mut self) {
        if let Some(target) = self.pending_conversion.take() {
            let perm = target.permutation_from(&self.current_layout);
            self.data = self.data.permute(&perm);
            self.current_layout = target;
        }
    }

    /// Get data, materializing if needed
    pub fn get(&mut self) -> &Tensor<T> {
        self.materialize();
        &self.data
    }

    /// Try to cancel pending conversion if next op prefers current layout
    pub fn try_cancel_conversion(&mut self, preferred: ConvLayout) -> bool {
        if self.current_layout == preferred {
            self.pending_conversion = None;
            true
        } else {
            false
        }
    }
}
5. Layout-Optimized Convolution
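conv2d_auto_layout() below depends on helpers this issue doesn't define (estimate_conv_flops, conv2d_optimized, conv2d_generic, the Padding type) and assumes LayoutTensor exposes layout(), shape(), and len() accessors. A sketch of the FLOP estimate, assuming NCHW/OIHW-ordered shapes:

// Rough FLOP count: 2 * N * C_out * H_out * W_out * C_in * Kh * Kw
// (one multiply + one add per MAC); padding is ignored for brevity.
fn estimate_conv_flops(input_shape: &[usize], kernel_shape: &[usize], stride: (usize, usize)) -> u64 {
    let (n, c_in) = (input_shape[0] as u64, input_shape[1] as u64);
    let (c_out, kh, kw) = (kernel_shape[0] as u64, kernel_shape[2] as u64, kernel_shape[3] as u64);
    let h_out = (input_shape[2] / stride.0) as u64;
    let w_out = (input_shape[3] / stride.1) as u64;
    2 * n * c_out * h_out * w_out * c_in * kh * kw
}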
/// Convolution with automatic layout handling
pub fn conv2d_auto_layout<T: TensorElement>(
    input: &LayoutTensor<T, 4>,
    kernel: &LayoutTensor<T, 4>,
    stride: (usize, usize),
    padding: Padding,
) -> LayoutTensor<T, 4> {
    let policy = LayoutPolicy::default_cpu();
    // Check if conversion is needed
    let need_input_convert = input.layout() != policy.preferred_input;
    let need_kernel_convert = kernel.layout() != policy.preferred_kernel;
    // Estimate costs
    let conv_flops = estimate_conv_flops(input.shape(), kernel.shape(), stride);
    let convert_input_cost = if need_input_convert {
        input.len() as f64 * policy.conversion_cost_factor
    } else {
        0.0
    };
    let convert_kernel_cost = if need_kernel_convert {
        kernel.len() as f64 * policy.conversion_cost_factor
    } else {
        0.0
    };
    // Heuristic: convert when the conversion overhead is small relative to
    // the conv itself; a finer model would weigh the optimized-vs-generic
    // speedup against the conversion cost.
    let use_preferred = conv_flops as f64 > convert_input_cost + convert_kernel_cost;
    if use_preferred {
        // Convert and use optimized kernel
        let input_converted = input.to_layout(policy.preferred_input);
        let kernel_converted = kernel.to_layout(policy.preferred_kernel);
        let output = conv2d_optimized(&input_converted, &kernel_converted, stride, padding);
        LayoutTensor::new(output, policy.preferred_output)
    } else {
        // Use current layout with generic kernel
        let output = conv2d_generic(input, kernel, stride, padding);
        LayoutTensor::new(output, input.layout())
    }
}
6. Im2col with Layout Support
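im2col_with_layout() below dispatches to im2col_nchw / im2col_nhwc, which the issue doesn't include. For reference, a slice-level sketch of the NCHW variant; the Tensor-based helpers used below are assumed to wrap logic like this:

// Reference im2col for NCHW: row-major input [b, c, h, w] into
// col [b, c*kh*kw, out_h*out_w]; zero padding outside the image.
fn im2col_nchw_ref(
    input: &[f32],
    col: &mut [f32],
    (batch, channels, height, width): (usize, usize, usize, usize),
    (kh, kw): (usize, usize),
    (sh, sw): (usize, usize),
    (ph, pw): (usize, usize),
) {
    let out_h = (height + 2 * ph - kh) / sh + 1;
    let out_w = (width + 2 * pw - kw) / sw + 1;
    for b in 0..batch {
        for c in 0..channels {
            for ki in 0..kh {
                for kj in 0..kw {
                    // Row of the col matrix this kernel tap writes.
                    let row = (c * kh + ki) * kw + kj;
                    for oy in 0..out_h {
                        for ox in 0..out_w {
                            // Input coordinates; isize handles padding underflow.
                            let iy = (oy * sh + ki) as isize - ph as isize;
                            let ix = (ox * sw + kj) as isize - pw as isize;
                            let dst = ((b * channels * kh * kw + row) * out_h + oy) * out_w + ox;
                            col[dst] = if iy >= 0 && (iy as usize) < height
                                && ix >= 0 && (ix as usize) < width
                            {
                                input[((b * channels + c) * height + iy as usize) * width + ix as usize]
                            } else {
                                0.0
                            };
                        }
                    }
                }
            }
        }
    }
}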
/// Im2col that respects input layout
pub fn im2col_with_layout<T: TensorElement>(
    input: &LayoutTensor<T, 4>,
    kernel_size: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
) -> Tensor<T> {
    let layout = input.layout();
    let shape = input.shape();
    let (batch, channels, height, width) = if layout.is_channels_first() {
        (shape[0], shape[1], shape[2], shape[3])
    } else {
        (shape[0], shape[3], shape[1], shape[2])
    };
    let out_h = (height + 2 * padding.0 - kernel_size.0) / stride.0 + 1;
    let out_w = (width + 2 * padding.1 - kernel_size.1) / stride.1 + 1;
    let col_size = channels * kernel_size.0 * kernel_size.1;
    let mut col = Tensor::zeros(&[batch, col_size, out_h * out_w]);
    // Layout-aware extraction
    if layout.is_channels_first() {
        im2col_nchw(input.data(), &mut col, kernel_size, stride, padding);
    } else {
        im2col_nhwc(input.data(), &mut col, kernel_size, stride, padding);
    }
    col
}
Acceptance Criteria
- ConvLayout with NCHW, NHWC, NCDHW variants
- KernelLayout with OIHW, HWIO variants
- ConvDimensionNumbers for full specification
- LayoutTensor with explicit layout tracking
- LayoutPolicy for different backends
- Dynamic programming layout sequence optimizer
- LazyLayoutTensor with deferred conversion
- Im2col with layout support
- Benchmarks for different layout scenarios (possible skeleton below)
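For the benchmarks item, one possible skeleton; this assumes the criterion crate, and tensor construction is stubbed out because the constructors above are not finalized:

// Benchmark skeleton sketch; bodies are stubs until the API settles.
use criterion::{criterion_group, criterion_main, Criterion};

fn bench_layouts(c: &mut Criterion) {
    c.bench_function("conv2d_nchw_native", |b| {
        b.iter(|| {
            // conv2d_auto_layout(&nchw_input, &oihw_kernel, (1, 1), padding)
        })
    });
    c.bench_function("conv2d_nhwc_on_nchw_backend", |b| {
        b.iter(|| {
            // conv2d_auto_layout(&nhwc_input, &hwio_kernel, (1, 1), padding)
        })
    });
}

criterion_group!(benches, bench_layouts);
criterion_main!(benches);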
Expected Performance Impact
| Scenario | Fixed Layout | Auto Layout | Improvement |
|---|---|---|---|
| Single conv (NCHW native) | 10ms | 10ms | 0% |
| Single conv (NHWC on NCHW) | 15ms | 12ms | 20% |
| Conv sequence (5 ops) | 75ms | 55ms | 27% |
| Mixed precision (convert chain) | 100ms | 60ms | 40% |
Roughly 20-40% improvement for non-native layouts and multi-op sequences; single convolutions already in the native layout are unaffected.
References
- JAX convolution: jax/_src/lax/convolution.py:38-180
- cuDNN Tensor Formats: https://docs.nvidia.com/deeplearning/cudnn/api/index.html
- oneDNN (formerly MKL-DNN) Memory Formats: https://oneapi-src.github.io/oneDNN/
Labels
performance, convolution, layout, P1-high