Differentiable Modulations Integration: Transition from PR #10 to PR #41
Summary
This issue tracks the development of differentiable modulation and demodulation operations for neural network training scenarios. The work was originally planned for PR #10 but has been redirected to PR #41 due to compatibility issues.
Background
The goal is to add gradient-enabled modulation and demodulation operations to Kaira, enabling gradient-based training of neural networks that include modulation layers. This opens up new possibilities for end-to-end learnable communication systems.
Key Features to Implement
1. Differentiable Operations
- New Module: kaira/modulations/differentiable.py - core utilities for differentiable operations
- Extended Base Classes: Add forward_soft methods to BaseModulator and BaseDemodulator
- Backward Compatibility: Existing forward methods remain unchanged
2. Supported Modulation Schemes
- BPSK: Differentiable Binary Phase Shift Keying
- QPSK: Differentiable Quadrature Phase Shift Keying
- QAM: Differentiable Quadrature Amplitude Modulation
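For all three schemes the differentiable path can be viewed as an expected-symbol relaxation: instead of thresholding bits and indexing the constellation, bit probabilities are mapped directly to weighted combinations of constellation points. The sketch below illustrates this idea for BPSK and QPSK; the function names are hypothetical and do not reflect the final kaira API.

import math
import torch

def soft_bpsk(p1: torch.Tensor) -> torch.Tensor:
    # Expected BPSK symbol under p(bit=1) = p1: (+1)*(1 - p1) + (-1)*p1 = 1 - 2*p1
    return 1.0 - 2.0 * p1

def soft_qpsk(p1: torch.Tensor) -> torch.Tensor:
    # Consecutive soft-bit pairs map to I/Q components (unit average symbol energy)
    pairs = p1.reshape(-1, 2)
    i = (1.0 - 2.0 * pairs[:, 0]) / math.sqrt(2.0)
    q = (1.0 - 2.0 * pairs[:, 1]) / math.sqrt(2.0)
    return torch.complex(i, q)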
3. Testing Framework
- Comprehensive test suite for differentiable operations
- Gradient verification tests
- End-to-end pipeline validation
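Gradient verification can lean on torch.autograd.gradcheck, which compares analytical and numerical Jacobians. The test below is a hypothetical sketch of such a check for a soft BPSK mapping; it does not assume any particular kaira test layout.

import torch

def test_soft_bpsk_gradients():
    # Hypothetical soft BPSK mapping: p(bit=1) -> expected symbol 1 - 2p
    soft_mod = lambda p: 1.0 - 2.0 * p

    # gradcheck needs double-precision inputs with requires_grad=True
    p = torch.rand(8, dtype=torch.float64, requires_grad=True)
    assert torch.autograd.gradcheck(soft_mod, (p,))

    # End-to-end check: gradients must reach the input through mod -> noise -> loss
    loss = (soft_mod(p) + 0.1 * torch.randn(8, dtype=torch.float64)).pow(2).mean()
    loss.backward()
    assert p.grad is not None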
Technical Implementation
Core API Extension
from abc import ABC, abstractmethod

import torch
from torch import nn

# Extended BaseModulator with differentiable path
class BaseModulator(nn.Module, ABC):
    @abstractmethod
    def forward_soft(self, soft_bits: torch.Tensor) -> torch.Tensor:
        """Differentiable modulation from soft bit probabilities to symbols."""
        pass

# Extended BaseDemodulator with differentiable path
class BaseDemodulator(nn.Module, ABC):
    @abstractmethod
    def forward_soft(self, symbols: torch.Tensor, noise_var: float) -> torch.Tensor:
        """Differentiable demodulation from symbols to soft bit probabilities."""
        pass
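To illustrate how a concrete scheme could plug into this interface, here is a hypothetical BPSK pair: the hard forward path stays exactly as it is today, while forward_soft adds the gradient-carrying path. The class names and details are assumptions for illustration, not the merged design (in practice these would subclass BaseModulator and BaseDemodulator).

import torch
from torch import nn

class DifferentiableBPSKModulator(nn.Module):
    def forward(self, bits: torch.Tensor) -> torch.Tensor:
        # Hard path (unchanged): bit 0 -> +1, bit 1 -> -1
        return 1.0 - 2.0 * bits.float()

    def forward_soft(self, soft_bits: torch.Tensor) -> torch.Tensor:
        # Soft path: expected symbol under p(bit=1) = soft_bits, fully differentiable
        return 1.0 - 2.0 * soft_bits

class DifferentiableBPSKDemodulator(nn.Module):
    def forward(self, symbols: torch.Tensor) -> torch.Tensor:
        # Hard decisions: negative received symbol -> bit 1
        return (symbols < 0).float()

    def forward_soft(self, symbols: torch.Tensor, noise_var: float) -> torch.Tensor:
        # For BPSK over AWGN the LLR is 2*y/noise_var, so p(bit=1) = sigmoid(-2*y/noise_var)
        return torch.sigmoid(-2.0 * symbols / noise_var)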
Usage Example
# Gradient-enabled modulation pipeline (proposed API)
import torch

modulator = QPSKModulator()
demodulator = QPSKDemodulator()

# Soft bits with gradients (e.g. the output of a neural network)
soft_bits = torch.tensor([0.1, 0.9, 0.2, 0.8], requires_grad=True)

# Differentiable modulation: soft bit probabilities -> symbols
symbols = modulator.forward_soft(soft_bits)

# Apply channel effects (channel is any differentiable channel model, e.g. AWGN)
noisy_symbols = channel(symbols)

# Differentiable demodulation: symbols -> soft bit probabilities
decoded_soft_bits = demodulator.forward_soft(noisy_symbols, noise_var=0.1)

# Gradient flow for training (loss_function and target_bits are task-specific placeholders)
loss = loss_function(decoded_soft_bits, target_bits)
loss.backward()  # Gradients flow through the entire pipeline
Development Status
- PR #10 (Add differentiable modulation and demodulation methods for BPSK, QPSK…): closed due to compatibility issues
- PR #41 (Add differentiable Gray-coded QAM modulation and demodulation): current development branch
- Status: In progress
Success Criteria
- Gradient Flow: All differentiable operations correctly propagate gradients
- Numerical Accuracy: Soft bit interpretations are mathematically correct
- Performance: Acceptable overhead compared to standard operations
- Compatibility: No breaking changes to existing API
- Documentation: Clear examples and API documentation
Future Extensions
- Additional modulation schemes (8-PSK, 16-QAM, etc.)
- MIMO system support
- Performance optimizations
- Mixed precision training support
Related Work
Integration with vector quantization methods may be explored: https://github.com/lucidrains/vector-quantize-pytorch
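For context, vector-quantize-pytorch relies on a straight-through estimator to pass gradients through its non-differentiable quantization step; a similar trick could in principle give hard constellation decisions a gradient path. A minimal sketch of the idea (not kaira code, BPSK assumed):

import torch

def straight_through_bpsk(soft_bits: torch.Tensor) -> torch.Tensor:
    soft_symbols = 1.0 - 2.0 * soft_bits     # differentiable surrogate
    hard_symbols = torch.sign(soft_symbols)  # non-differentiable hard decision
    # Forward pass emits hard symbols; backward pass uses the soft path's gradient
    return soft_symbols + (hard_symbols - soft_symbols).detach()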