Pull Request: Machine Learning Integration for DaCe

Overview

This PR adds comprehensive machine learning capabilities to DaCe through three tightly integrated components:

  1. Automatic Differentiation (AD) - Reverse-mode gradient computation for SDFGs
  2. ONNX Integration - Import and execute neural network models
  3. PyTorch Integration - Bidirectional interoperability with PyTorch's autograd system

Together, these components enable DaCe to optimize and accelerate machine learning workloads, particularly neural network training and inference.

High-Level Architecture

PyTorch Model
     ↓
  ONNX Export
     ↓
DaCe SDFG (Forward)
     ↓
Automatic Differentiation
     ↓
DaCe SDFG (Backward)
     ↓
Compiled Code Generation
     ↓
PyTorch Operator (with Autograd)

Component 1: Automatic Differentiation (dace/autodiff/)

Purpose

Provides reverse-mode automatic differentiation for SDFGs, enabling gradient computation for any DaCe program. This is the foundation for neural network training and gradient-based optimization.

Key Capabilities

  • Full SDFG Support: Differentiates maps, tasklets, nested SDFGs, loops, and library nodes
  • Control Flow: Handles loops (LoopRegion) and conditionals
  • ONNX Operations: 50+ backward implementations for ONNX operators
  • Data Forwarding: Flexible strategies (store vs. recompute) for memory/compute tradeoffs
  • Extensible Registry: Plugin-based system for adding backward rules

Core Algorithm

  1. Forward Pass Execution: Run original computation and identify required intermediates
  2. Backward Pass Generation: Traverse computation graph in reverse, accumulating gradients
  3. Node Reversal: Each forward node (Map, Tasklet, ONNXOp) has a registered backward implementation
  4. Gradient Accumulation: Use write-conflict resolution (WCR) for multi-path gradients
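
As a rough sketch of how this could be driven from user code (the module path and the add_backward_pass signature below are assumptions based on this component, not a confirmed API):

import numpy as np
import dace
from dace.autodiff import add_backward_pass  # assumed entry point

@dace.program
def square_sum(X: dace.float64[10], S: dace.float64[1]):
    S[0] = np.sum(X * X)

sdfg = square_sum.to_sdfg()
# Request gradients of S with respect to X; contributions from multiple
# paths are accumulated via WCR memlets (e.g. wcr="lambda a, b: a + b")
add_backward_pass(sdfg, inputs=["X"], outputs=["S"])  # assumed signature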

Key Files

| File | Lines | Purpose |
|------|-------|---------|
| backward_pass_generator.py | ~800 | Core AD engine that orchestrates backward pass generation |
| implementations/onnx_ops.py | ~2000 | Backward implementations for 50+ ONNX operations |
| implementations/dace_nodes.py | ~600 | Backward rules for core SDFG elements (Tasklet, Map, etc.) |
| data_forwarding/manager.py | ~300 | Store vs. recompute strategy coordination |

Component 2: ONNX Integration (dace/libraries/onnx/)

Purpose

Enables importing and executing ONNX neural network models within DaCe. Converts ONNX graphs to optimized DaCe SDFGs for efficient execution on CPU/GPU.

Key Capabilities

  • Model Import: Load ONNX models from files or protobuf objects
  • 100+ Operations: Dynamically generated node classes for all ONNX ops
  • Shape Inference: Automatic symbolic and concrete shape computation
  • Multi-Strategy Implementations: Pure (correctness), optimized (performance), hardware-specific
  • Type Safety: Schema-based validation and type checking

Core Architecture

Dynamic Node Generation:

  • Registry system generates Python classes for all ONNX operations at import time
  • Each operation has schema, properties, connectors, and implementations
  • Example: ONNXConv, ONNXMatMul, ONNXSoftmax (100+ generated classes)
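
For illustration, a generated node class can be placed in an SDFG like any other library node; the wiring below is a sketch (import path assumed), with connector names taken from the ONNX MatMul schema:

import dace
from dace.libraries.onnx import ONNXMatMul  # class generated by the registry; path assumed

sdfg = dace.SDFG("matmul_example")
sdfg.add_array("A", [4, 8], dace.float32)
sdfg.add_array("B", [8, 2], dace.float32)
sdfg.add_array("Y", [4, 2], dace.float32)

state = sdfg.add_state()
node = ONNXMatMul("matmul")
state.add_node(node)
# Connector names (A, B, Y) follow the ONNX operator schema
state.add_edge(state.add_read("A"), None, node, "A", sdfg.make_array_memlet("A"))
state.add_edge(state.add_read("B"), None, node, "B", sdfg.make_array_memlet("B"))
state.add_edge(node, "Y", state.add_write("Y"), None, sdfg.make_array_memlet("Y"))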

Implementation Strategies:

  1. Pure Implementations (pure_implementations.py): Reference implementations in Python/NumPy
  2. Optimized Implementations (img_op_implementations.py): Hand-crafted SDFGs for performance
  3. Hardware-Specific: Future GPU/FPGA specialized implementations
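
Conceptually, a pure implementation is a NumPy-syntax DaCe program for a single operator. A sketch for ONNX Relu (the registration decorator used in pure_implementations.py is omitted here):

import numpy as np
import dace

@dace.program
def relu_pure(X: dace.float32[64], Y: dace.float32[64]):
    # Reference semantics of ONNX Relu
    Y[:] = np.maximum(X, 0)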

Import Pipeline:

ONNX Model → Validation → Shape Inference → Simplification → SDFG Construction → Compilation
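
A sketch of driving this pipeline from Python (the ONNXModel class name and import path are assumptions based on onnx_importer.py):

import onnx
import numpy as np
from dace.libraries.onnx import ONNXModel  # assumed import path

model_proto = onnx.load("model.onnx")            # file or protobuf object
dace_model = ONNXModel("my_model", model_proto)  # validation, shape inference, SDFG construction
# Example invocation; input shape depends on the model
outputs = dace_model(np.random.rand(1, 3, 224, 224).astype(np.float32))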

Key Files

| File | Lines | Purpose |
|------|-------|---------|
| onnx_importer.py | 711 | Main entry point, orchestrates import pipeline |
| op_implementations/pure_implementations.py | 3052 | Reference implementations for 40+ operations |
| nodes/onnx_op_registry.py | 325 | Dynamic node class generation |
| schema.py | 390 | Type system and validation |
| shape_inference/symbolic_shape_infer.py | 1976 | Symbolic shape inference (Microsoft-sourced) |

Component 3: PyTorch Integration (dace/libraries/torch/)

Purpose

Provides bidirectional integration between PyTorch and DaCe. Enables optimizing PyTorch models with DaCe while maintaining PyTorch's autograd compatibility.

Key Capabilities

  • Model Optimization: Convert torch.nn.Module to optimized DaCe SDFGs
  • Autograd Integration: Backward pass generation integrates with PyTorch's autograd
  • Dual Dispatch: C++ extension (performance) or CTypes (flexibility)
  • Zero-Copy Tensors: DLPack protocol for efficient memory sharing
  • Training Support: Full forward + backward pass compilation

Core Architecture

Integration Flow:

PyTorch Model → ONNX Export → DaCe SDFG → Backward Generation → Compilation → PyTorch Operator

Dispatcher Strategies:

  1. C++ Extension (cpp_torch_extension.py): Native PyTorch operator with autograd
    • High performance
    • 64 parameter limit
    • Slower compilation
  2. CTypes Module (ctypes_module.py): Pure Python dispatcher
    • Unlimited parameters
    • Faster compilation
    • Slight overhead
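
The dispatcher could be selected when wrapping the model; the flag name below is hypothetical and only illustrates the choice between the two strategies:

from dace.frontend.python import DaceModule

# compile_torch_extension is a hypothetical flag name; model and dummy_inputs
# are as in the training example further below
dace_model = DaceModule(model, dummy_inputs, backward=True,
                        compile_torch_extension=False)  # use the CTypes dispatcher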

Zero-Copy Memory Sharing:

  • DLPack protocol enables PyTorch tensors to view DaCe memory without copying
  • Bidirectional: DaCe → PyTorch (outputs) and PyTorch → DaCe (inputs)
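
The mechanism can be illustrated with the standard DLPack APIs of recent NumPy/PyTorch versions (the DaCe-side wrapper in dlpack.py is not shown; this only demonstrates the zero-copy property):

import numpy as np
import torch

buffer = np.zeros((4, 4), dtype=np.float32)  # stand-in for a DaCe-owned array
tensor = torch.from_dlpack(buffer)           # zero-copy view over the same memory

tensor[0, 0] = 42.0
assert buffer[0, 0] == 42.0                  # writes are visible on both sides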

Key Files

| File | Lines | Purpose |
|------|-------|---------|
| dispatchers/cpp_torch_extension.py | 717 | C++ code generation for PyTorch operators |
| dispatchers/ctypes_module.py | 230 | CTypes-based dispatcher |
| dlpack.py | 199 | Zero-copy tensor sharing via DLPack |
| environments/pytorch_env.py | 94 | CMake build configuration |

How Components Work Together

Example: Training a PyTorch Model with DaCe

import torch
from dace.frontend.python import DaceModule

# 1. Define PyTorch model (MyNeuralNetwork, criterion, and dataloader are
#    assumed to be defined elsewhere)
model = MyNeuralNetwork()
optimizer = torch.optim.Adam(model.parameters())

# 2. Wrap with DaCe (compiles on first call); dummy_inputs holds example
#    tensors that fix the input shapes for compilation
dace_model = DaceModule(model, dummy_inputs, backward=True)

# 3. Training loop (standard PyTorch code)
for inputs, labels in dataloader:
    optimizer.zero_grad()
    outputs = dace_model(inputs)  # DaCe-optimized forward pass
    loss = criterion(outputs, labels)
    loss.backward()  # DaCe-optimized backward pass
    optimizer.step()

What Happens Internally:

  1. First Call: PyTorch model → ONNX export → DaCe SDFG (via ONNX integration)
  2. Backward Generation: Forward SDFG → Backward SDFG (via autodiff)
  3. Compilation: Both SDFGs compiled to optimized code
  4. Dispatcher: C++ extension or CTypes wrapper created
  5. Forward Pass: DaCe executes optimized forward computation
  6. Backward Pass: DaCe executes generated backward computation
  7. Gradient Return: Gradients flow back to PyTorch optimizer

Data Flow

PyTorch Tensor (input)
    ↓ Zero-copy (DLPack)
DaCe Array
    ↓ Optimized SDFG Execution
DaCe Array (output)
    ↓ Zero-copy (DLPack)
PyTorch Tensor (output)
    ↓ loss.backward()
PyTorch Tensor (grad_output)
    ↓ Zero-copy (DLPack)
DaCe Array (backward pass input)
    ↓ Backward SDFG Execution
DaCe Array (grad_input)
    ↓ Zero-copy (DLPack)
PyTorch Tensor (grad_input)

Testing Strategy

Test Organization

tests/
├── autodiff/                       # AD correctness tests
│   ├── test_single_state.py        # Basic AD operations
│   └── torch/                      # PyTorch integration tests
│       ├── test_training.py        # End-to-end training
│       ├── test_bert_encoder_backward.py    # BERT model
│       └── test_llama_decoder_backward.py   # LLaMA model
│
├── onnx/                           # ONNX import tests
│   ├── test_python_frontend.py     # Basic operations
│   ├── test_bert_subgraphs.py      # Real model subgraphs
│   └── test_input_outputs.py       # I/O handling
│
├── torch/                          # PyTorch integration tests
│   ├── test_lenet.py               # Simple CNN
│   ├── test_bert_encoder.py        # Transformer encoder
│   └── test_llama_decoder.py       # Decoder architecture
│
└── npbench/                        # AD tests on NPBench kernels

Test Coverage

| Component | Test Files | Coverage |
|-----------|------------|----------|
| Autodiff Core | 15+ files | Tasklets, maps, loops, nested SDFGs |
| ONNX Integration | 20+ files | Import, execution, type handling |
| PyTorch Integration | 15+ files | Forward, backward, training loops |

Running Tests

# All basic tests (excluding long-running tests)
pytest -m "(autodiff or torch or onnx) and not long" tests/

# AD tests only
pytest tests/autodiff/

# ONNX tests only
pytest tests/onnx/

# PyTorch tests only
pytest tests/torch/

Known Limitations and Future Work

Current Limitations

  1. Recompute Strategy: Experimental, not production-ready
  2. Control Flow: Conditionals are inlined into the state machine (not reversed as ControlFlowRegions)
  3. Second-Order Gradients: Not yet tested

Documentation

Each component has detailed design documentation. These documents provide:

  • Detailed component descriptions
  • Algorithm explanations
  • Code walkthrough
  • Extension points
  • Implementation notes

Impact on DaCe

Code Additions

| Component | Lines of Code | Files |
|-----------|---------------|-------|
| Autodiff | ~8,000 | 15+ files |
| ONNX | ~7,000 | 20+ files |
| PyTorch | ~1,500 | 10+ files |
| Total | ~16,500 | 45+ files |

Dependencies

New dependencies (already in setup.py):

  • onnx - ONNX model format
  • onnxsim - ONNX graph simplification
  • torch - PyTorch framework (optional)
  • protobuf - Protocol buffers (for ONNX)
  • jax - For numerical validation of gradients in tests
  • transformers - For testing the PyTorch/ONNX frontends
  • efficientnet_pytorch - For testing EfficientNet
