Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
228 commits
Select commit Hold shift + click to select a range
5a6dc4b
Initial commit: Merge DaCeML into DaCe
affifboudaoud May 3, 2024
3c5219b
Merge branch 'master' of https://github.com/spcl/dace into autodiff
affifboudaoud May 3, 2024
0a5721c
store-all initial implementation
affifboudaoud May 7, 2024
7f9ca01
structural changes and formatting
affifboudaoud May 7, 2024
d7b4c64
yapf refactoring
affifboudaoud May 7, 2024
32d8761
store all implementation for non-base-level nodes
affifboudaoud May 8, 2024
9fd821f
fix store-all implementation
affifboudaoud May 13, 2024
02d0d7e
recompute: initial implementation
affifboudaoud May 14, 2024
60f3ea8
recomputation: small fixes
affifboudaoud May 14, 2024
e8c6933
Initial support for multistate SDFGs
affifboudaoud May 25, 2024
579c18a
[incomplete] Initial support for multistate AD
affifboudaoud Jun 3, 2024
68e1cd1
[Incomplete] additional features for multistate support
affifboudaoud Jun 5, 2024
18df683
Fix ONNX model loading for duplicated inputs/outputs
affifboudaoud Jun 13, 2024
dad803b
Added AD for NPBench code
affifboudaoud Jun 25, 2024
24caf25
[in progress] Added fix for conditional array assignment
affifboudaoud Jun 26, 2024
a2475ad
remove log sdfgs from commit
affifboudaoud Jun 26, 2024
cdb4a74
[in progress] Conditional tasklet support
affifboudaoud Jun 27, 2024
5973196
[in progress] Improved tasklet reversal to treat conditional tasklets
affifboudaoud Jun 30, 2024
71b034c
[in progress] Conditional tasklets and store strategy improvements
affifboudaoud Jul 1, 2024
8f6024f
Additional fixes to storing and removed unnecessary constraints
affifboudaoud Jul 2, 2024
2e326af
Merge branch 'master' of https://github.com/spcl/dace into loops
affifboudaoud Jul 3, 2024
a11da41
[In progress] Added support for CFG loop representation
affifboudaoud Jul 5, 2024
8c50074
[In progress] Additional improvements in LoopRegion support
affifboudaoud Jul 8, 2024
c34ade8
Merge branch 'master' of https://github.com/spcl/dace into loops_merge
affifboudaoud Jul 8, 2024
7b981b5
Merge branch 'master' of https://github.com/spcl/dace into loops_merge
affifboudaoud Jul 9, 2024
0df8ac9
Fixed reversal of NestedSDFGs
affifboudaoud Jul 9, 2024
ec29076
Fixed reversal of NestedSDFGs
affifboudaoud Jul 9, 2024
e875d1c
[In progress] Refactoring storing strategy
affifboudaoud Jul 12, 2024
256bcb9
[In progress] Further improvements to storing
affifboudaoud Aug 5, 2024
b5019c6
Initial implementation of recomputation
affifboudaoud Aug 9, 2024
43f35df
[In Progress] Initial implementation of ILP formulation
affifboudaoud Oct 9, 2024
3f0ddb1
[In Progress] Added conditional support to the ILP
affifboudaoud Oct 9, 2024
500e5cb
Improvements of ILP solution
affifboudaoud Nov 18, 2024
b1cec69
Merge remote-tracking branch 'upstream/main' into loops_merge
affifboudaoud Nov 18, 2024
31cfe10
Additional Fixes to loop gradient accumulation
affifboudaoud Nov 20, 2024
53c6eb1
Added required dependencies for DaCeML
affifboudaoud Nov 20, 2024
4fcfb11
Fixed storing with symbolic size case
affifboudaoud Jan 18, 2025
1868df6
Fixes to unnecessary squeezing in BLAS library nodes
phschaad Jan 20, 2025
2ffb84c
Added gradient reinitialization for overwritten forward arrays
affifboudaoud Jan 20, 2025
30bed74
Additional fixes to memlet creation in gradient clearing
affifboudaoud Jan 20, 2025
5188065
Fixed loop check in CCS extraction
affifboudaoud Jan 20, 2025
9c711fb
Fixes to array removal, removing non-redundant slices
phschaad Jan 21, 2025
fa98fba
Additional fixes to gradient clearing and reversal of maps with no in…
affifboudaoud Jan 21, 2025
cdfd289
Preserve reshapes for SDFG inlining when used in library nodes
phschaad Jan 21, 2025
bfa15f4
Ensure redundant array does not remove reshapes for library nodes
phschaad Jan 21, 2025
0383d1e
Removing casting from tasklets to be able to use sympy diff
affifboudaoud Jan 22, 2025
3fbc7fd
Merge remote-tracking branch 'upstream/main' into gradient-accumulation
affifboudaoud Jan 22, 2025
7706db4
Improvements to storing within decreasing loops
affifboudaoud Jan 22, 2025
357549a
Fix for storing views that point to NestedSDFGs
affifboudaoud Jan 23, 2025
9cf093e
Code generation fixes
tbennun Jan 24, 2025
7e2ec23
Pre-conditional-assignement fix
affifboudaoud Jan 27, 2025
9d0b09d
Multiple fixes for NPBench AD
affifboudaoud Jan 29, 2025
fb24b24
Cleanup + small fixes including allowing runtime-size allocation
affifboudaoud Feb 1, 2025
89cd1f1
Decide on MatMul expansion using unsqueezed shape
affifboudaoud Feb 4, 2025
a13fd30
Added back to libnode implementation for DOT
affifboudaoud Feb 11, 2025
01d331a
Merge branch 'users/afif/gemm_fixes_no_squeeze' of https://github.com…
affifboudaoud Feb 11, 2025
a446579
Various fixes to gradient clearing and auto-opt for the backward pass
affifboudaoud Feb 16, 2025
0c336c0
Inline ConditionalBlocks and add inter-state assignement edges
affifboudaoud Feb 17, 2025
cb7fabf
Trying to fix Softmax
affifboudaoud Feb 21, 2025
fa487c8
Add max reduction and clear multiple writes
affifboudaoud Feb 23, 2025
fbe4562
NPBench specific fixes
affifboudaoud Jun 23, 2025
824becf
Merge Snitch changes
affifboudaoud Jun 25, 2025
f03f622
Added SciPy to dependencies + formatting
affifboudaoud Jun 25, 2025
008026c
Add pip dependency
and-ivanov Jun 27, 2025
aa4d836
Make daceml tests discoverable
and-ivanov Jun 27, 2025
40e2338
fix frontend test
and-ivanov Jun 30, 2025
0a9f71e
Fix paths
and-ivanov Jun 30, 2025
988e230
Remove obsolete test that ensures only floating point computations ar…
and-ivanov Jun 30, 2025
f693f9b
Remove obsolete single state test
affifboudaoud Jun 30, 2025
b9812f5
Merge branch 'update_efforts' of https://github.com/affifboudaoud/dac…
affifboudaoud Jun 30, 2025
bc314d7
Remove deprecated np.bool
affifboudaoud Jun 30, 2025
636d1ea
Fix API usage in test SDFGBackwardRunner
and-ivanov Jun 30, 2025
d1ade4a
Fix conv implementation for default strides and pads
and-ivanov Jun 30, 2025
884d015
Fix ONNX operator expansions and their tests
and-ivanov Jul 1, 2025
a0dce78
Only use descriptor names to get AD data
affifboudaoud Jul 1, 2025
946d1b9
Added simplify to test_nested to avoid FunctionCallRegions
affifboudaoud Jul 1, 2025
25e3cc1
Added backward pass for Min reduction
affifboudaoud Jul 2, 2025
9353d90
Remove init state transformation and test since we support multiple s…
affifboudaoud Jul 3, 2025
431493b
Update bfs api
affifboudaoud Jul 3, 2025
baecfeb
Forward data if seprate_sdfgs
affifboudaoud Jul 3, 2025
f9e3ab6
Formatting + switch back to CPU
affifboudaoud Jul 3, 2025
71a9761
Formatting + check for signature if seprate_sdfgs
affifboudaoud Jul 3, 2025
c2751ad
Fix for axes_arr of shape 1 + Formatting
affifboudaoud Jul 3, 2025
1c6cf08
Multiple fixes to tests + Added main calls for debugging
affifboudaoud Jul 3, 2025
da7fe58
Make compatible with latest onnxruntime
and-ivanov Jul 7, 2025
ba9600c
Initialize unused arguments + Formatting
affifboudaoud Jul 8, 2025
794cde0
Fix test_input_outputs.py
and-ivanov Jul 8, 2025
9d49d6d
fix expansions
and-ivanov Jul 8, 2025
e3e8b03
fix test_bert.py: express softmax expansion in terms of simpler onnx …
and-ivanov Jul 9, 2025
7b104f3
resolve onnx/onnxruntime versioning issues
and-ivanov Jul 11, 2025
30e8394
Cleanup onnxruntime use
and-ivanov Jul 14, 2025
0998b0b
fix test_shared_input_output.py
and-ivanov Jul 14, 2025
59dc707
Fix expansions for Add,Sub,Mul,Div
and-ivanov Jul 15, 2025
0e2e7d3
fix test_conv2d.py
and-ivanov Jul 15, 2025
afa8256
Changed test dtype to float64
affifboudaoud Jul 15, 2025
bebe5de
Fixed tensors_close print order
affifboudaoud Jul 15, 2025
e09cd82
Fix AccessSets analysis api call
affifboudaoud Jul 15, 2025
6078c77
Formatting with yapf
affifboudaoud Jul 15, 2025
0f5861a
Fix BackwardPass node creation and validation
affifboudaoud Jul 15, 2025
3056c71
Remove FuncitonCallRegions before AD and initialize containers to zero
affifboudaoud Jul 15, 2025
361746b
Fix batchnorm implementation
and-ivanov Jul 16, 2025
70dc170
Express GlobalAveragePool through ReduceMean
and-ivanov Jul 16, 2025
4c39da5
Iterate over state views not loop state views
affifboudaoud Jul 16, 2025
0d5d613
Merge branch 'update_efforts' of https://github.com/affifboudaoud/dac…
affifboudaoud Jul 16, 2025
5193e4f
Refactor ReduceMax,Min,Sum,Mean and fix ambiguity in passing scalars …
and-ivanov Jul 16, 2025
dd580c1
Add Llama Decoder inference test
affifboudaoud Jul 17, 2025
dd6a534
Add LlamaForCausalLM test
affifboudaoud Jul 17, 2025
9e79460
Fix inf initialization + increase size limit for arrays
affifboudaoud Jul 17, 2025
1a25959
Add new tensorproto format
affifboudaoud Jul 31, 2025
4f3a5c1
Add new pure implementations + formatting
affifboudaoud Jul 31, 2025
39f6899
Fix initialization for constant arrays that need to be forwarded to t…
affifboudaoud Jul 31, 2025
21f1c12
Fix initialize_outputs_code call
affifboudaoud Aug 1, 2025
2d85ec9
Added wcr sum to einsum backward output and fixed einsum expansion in…
affifboudaoud Aug 6, 2025
9a75a94
Remove debug code
affifboudaoud Aug 6, 2025
630771d
Add Llama decoder backward test
affifboudaoud Aug 6, 2025
4edeaaa
Additional fixes to inf initializations
affifboudaoud Aug 6, 2025
5a3e5a8
Add support for indirection
affifboudaoud Aug 6, 2025
2f01229
Add initialization for integer tensors
affifboudaoud Aug 7, 2025
0eff710
Remove constant inputs when constructing ONNX op replacements
affifboudaoud Aug 7, 2025
511d9cd
Avoid gradient tracking for ONNX op attributes
affifboudaoud Aug 7, 2025
2488b01
Fix ReduceSum backward implementation
affifboudaoud Aug 7, 2025
4cbac69
Remove debug code
affifboudaoud Aug 7, 2025
33e7e88
Enable ONNX simplify by default
affifboudaoud Aug 7, 2025
447d6ba
Fix ReduceMax backward implementation
affifboudaoud Aug 8, 2025
461a674
Add register storage ONNX codegen
affifboudaoud Aug 8, 2025
741b0b1
Fix ReduceMean reduction conditions
affifboudaoud Aug 8, 2025
1e8a1ca
Remove unnecessary wcr sum check
affifboudaoud Aug 8, 2025
92357d2
Add pure BatchNorm implementation
affifboudaoud Aug 8, 2025
573d753
Remove size limit for arrays
affifboudaoud Aug 8, 2025
c0ff565
Add ninja dependency and limit ONNX to 1.17
affifboudaoud Aug 8, 2025
7f678ab
Add specific SDFG names to avoid folder mismatch with pytest
affifboudaoud Aug 8, 2025
ecced32
Avoid simplifying models for now
affifboudaoud Aug 8, 2025
a580723
Remove unused imports
affifboudaoud Aug 8, 2025
5c068c8
Fix Einsum expansion to avoid duplicate descriptors
affifboudaoud Aug 8, 2025
89846a3
Fix LayerNormalization backward implmenetation
affifboudaoud Aug 12, 2025
c69426a
Add full Llama backward test
affifboudaoud Aug 12, 2025
452e446
Add pure ReduceSum implementation + Extend ReduceMean
affifboudaoud Aug 12, 2025
a71d04e
Fix LayerNormalization reduction axes
affifboudaoud Aug 12, 2025
243b6d7
Remove obsolete tests and transformations
affifboudaoud Aug 14, 2025
2200199
Update ORT C API and raw bindings
affifboudaoud Aug 14, 2025
4215071
Set constant attributes for ONNX nodes
affifboudaoud Aug 14, 2025
9d7dbe9
Improve tests by verifying all gradients + increase batch size
affifboudaoud Aug 15, 2025
f2143b8
Multiple fixes to reduction axes in pure expansions
affifboudaoud Aug 15, 2025
965b327
Attempting to fix ORT C API
affifboudaoud Aug 15, 2025
aae9d13
Remove unnecessary views + obsolete GPU schedule code
affifboudaoud Aug 20, 2025
c977bce
Remove CPP implementations and improve softmax
affifboudaoud Aug 20, 2025
ac72044
Remove old Pow implementation
affifboudaoud Aug 21, 2025
314402c
Fix forwarded value non-zero initialization
affifboudaoud Aug 21, 2025
217fcfa
Remove debug SDFG save
affifboudaoud Aug 22, 2025
86147c4
Add zero initializations
affifboudaoud Aug 22, 2025
9f09de1
Add CopyToMap for GPU pass
affifboudaoud Aug 22, 2025
ea3c0b7
Merge remote-tracking branch 'origin/main' into dace_ad
affifboudaoud Aug 22, 2025
d2ad6b9
Merge lefover
affifboudaoud Aug 22, 2025
f04c534
Adapt to new API from merge
affifboudaoud Aug 22, 2025
cbdcc3d
Removed seprate dir for NPBench AD and added AD test prototype to k2mm
affifboudaoud Sep 10, 2025
78fe0e5
Added AD NPBench tests
affifboudaoud Sep 11, 2025
bc52b50
Add expand operator and default value for steps in Slice
affifboudaoud Sep 11, 2025
4b10931
Add all AD NPBench tests
affifboudaoud Sep 24, 2025
5898e30
Fix gradient clearing
affifboudaoud Sep 24, 2025
4137b5d
Formatting
affifboudaoud Sep 24, 2025
052d82b
Remove obsolete tests
affifboudaoud Sep 27, 2025
8fef9c6
Minor changes to tests + Formatting
affifboudaoud Sep 27, 2025
c02a43d
Formatting
affifboudaoud Sep 27, 2025
0ee5b94
Avoid DDE in constant folding + Formatting
affifboudaoud Sep 27, 2025
8045b02
Fix boolean tensor initialization
affifboudaoud Sep 27, 2025
2ccb46d
Add Dropout forward impl + Formatting
affifboudaoud Sep 27, 2025
40daa99
Fixes to BatchNorm + Formatting
affifboudaoud Sep 27, 2025
8dd6764
Disable auto-opt by default
affifboudaoud Sep 27, 2025
6bd42af
Add hooks before function init
affifboudaoud Sep 27, 2025
e2701ad
Formatting
affifboudaoud Sep 27, 2025
84f46e0
Formatting
affifboudaoud Sep 27, 2025
31ee8cb
Check for FunctionCallRegion in autodiff analysis
affifboudaoud Sep 27, 2025
a55aa67
Gradient clearing for single value arrays + Isolated node removal
affifboudaoud Sep 27, 2025
7974ac2
Fix codegen for Indices subsets
affifboudaoud Sep 27, 2025
1ebd677
Set transformers version to 4.5
affifboudaoud Sep 27, 2025
daa3a74
Remove GPU test for now
affifboudaoud Sep 27, 2025
30078b5
Remove GPU tests for now
affifboudaoud Sep 27, 2025
614fc0b
Remove unnecessary fixtures and remaining GPU tests
affifboudaoud Sep 28, 2025
2b492bf
Restructure tests and add onnx marker
affifboudaoud Sep 28, 2025
085d4d4
Update pytest marker
affifboudaoud Sep 28, 2025
725e8d6
Remove AD auto-opt until transformed into passes
affifboudaoud Sep 28, 2025
b80d2b4
Remove ONNXRuntime dependency
affifboudaoud Sep 28, 2025
3e87eb9
Remove AD auto-opt
affifboudaoud Sep 29, 2025
7f4a561
[Restructuring] Moved functions to utils and removed experimental dyn…
affifboudaoud Sep 29, 2025
acf0def
Seprate SDFG element reversal from generator
affifboudaoud Sep 29, 2025
2973c30
Separate more functions to utils and dace_nodes
affifboudaoud Sep 29, 2025
703b1f3
[Restructuring] Moved storing and recomputation strategies into own dir
affifboudaoud Sep 29, 2025
faa472e
Fix typo
affifboudaoud Sep 29, 2025
84dd7fc
Improve documentation
affifboudaoud Sep 29, 2025
afe1ecc
Remove unnecessary ONNXRuntime backend
affifboudaoud Sep 30, 2025
a8f411a
Remove onnx reporter
affifboudaoud Sep 30, 2025
e4ef648
Remove unnecessary testing dir
affifboudaoud Sep 30, 2025
365ff71
Add design documents for each module
affifboudaoud Sep 30, 2025
40c323d
Improve tests error messages and formatting
affifboudaoud Sep 30, 2025
f11ec23
Better documentation
affifboudaoud Sep 30, 2025
fcc14a5
Fix assertion in dlpack test
affifboudaoud Sep 30, 2025
f22dc61
Add comments
affifboudaoud Sep 30, 2025
1b8546a
Make sure to compare to dace gradients when testing
affifboudaoud Sep 30, 2025
595c755
Remove OpenBLAS dependency
affifboudaoud Oct 1, 2025
d65ff92
Add midding test packages
affifboudaoud Oct 1, 2025
e4eb08c
Merge remote-tracking branch 'origin/main' into dace_ad
affifboudaoud Oct 1, 2025
8975401
Add missing package + Formatting
affifboudaoud Oct 1, 2025
6168e9f
Allow Python 3.13 and ONNX 1.18
affifboudaoud Oct 1, 2025
1fd631f
Set onnx IR version explicitly
affifboudaoud Oct 1, 2025
073292a
Pre-commit formatting
affifboudaoud Oct 1, 2025
fdc0e3f
Serialization fixes
affifboudaoud Oct 1, 2025
4d0792b
Fix paths for cpp extensions
affifboudaoud Oct 1, 2025
877833a
Unique auto_opt name and expansion edge case
affifboudaoud Oct 2, 2025
4a69aa9
Skip some AD tests until serialization issue is fixed
affifboudaoud Oct 2, 2025
2b5d29a
Revert to main code
affifboudaoud Oct 2, 2025
57168e0
Remove conda specific import
affifboudaoud Oct 2, 2025
37030b1
Use expanded sdfgs instead of function call
affifboudaoud Oct 2, 2025
f096ad0
Make Torch and ONNX dependencies optional
affifboudaoud Oct 2, 2025
747e1e5
Update CI installation
affifboudaoud Oct 2, 2025
1bd2154
Update all CI installations
affifboudaoud Oct 2, 2025
8b73bf5
Avoid conflicting names got batch size in MKL implementation
affifboudaoud Oct 3, 2025
fafc460
Build Torch module in unique dir to avoid baton issues
affifboudaoud Oct 3, 2025
2aa4e3a
Attempting to reduce CI runtime with smaller sizes
affifboudaoud Oct 3, 2025
d2d31b9
Simplify durbin test
affifboudaoud Oct 3, 2025
f4b2c7c
Simplify resent
affifboudaoud Oct 3, 2025
1dc4d8b
Even smaller sizes for cavity_flow
affifboudaoud Oct 3, 2025
ad34aba
Avoid data race in loop lifiting test
affifboudaoud Oct 3, 2025
f07a6e6
Formatting
affifboudaoud Oct 3, 2025
477738b
Set JAX version to avoid conflict with cupy
affifboudaoud Oct 3, 2025
dec4703
set JAX to <= 0.6.2
affifboudaoud Oct 3, 2025
341d3b3
Smaller inputs for Cholesky
affifboudaoud Oct 3, 2025
21a2538
Make ReplacementTransformation abstract to pass coverage tests
affifboudaoud Oct 3, 2025
c998f1e
Remove redundant ReverseReduceMax class
affifboudaoud Oct 7, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/copilot-setup-steps.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,6 @@ jobs:

- name: Install DaCe in development mode
run: |
python -m pip install --editable ".[testing,linting]"
python -m pip install --editable ".[testing,linting,ml]"
pre-commit install
pre-commit run
2 changes: 1 addition & 1 deletion .github/workflows/fpga-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
python -m pip install --upgrade pip
pip install pytest-xdist flake8 coverage click
pip uninstall -y dace
pip install -e ".[testing]"
pip install -e ".[testing,ml]"
curl -Os https://uploader.codecov.io/latest/linux/codecov
chmod +x codecov

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/general-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
sudo apt-get install -y verilator # RTL simulation dependencies
python -m pip install --upgrade pip
pip install flake8 pytest-xdist coverage
pip install -e ".[testing]"
pip install -e ".[testing,ml]"
curl -Os https://uploader.codecov.io/latest/linux/codecov
chmod +x codecov

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/gpu-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
pip install mpi4py
pip install cupy
pip uninstall -y dace
pip install -e ".[testing]"
pip install -e ".[testing,ml]"
curl -Os https://uploader.codecov.io/latest/linux/codecov
chmod +x codecov

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/hardware_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
python -m pip install --upgrade pip
pip install pytest-xdist flake8
pip uninstall -y dace
pip install -e ".[testing]"
pip install -e ".[testing,ml]"

- name: Run FPGA Tests
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/heterogeneous-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
pip install flake8 pytest-xdist coverage
pip install mpi4py pytest-mpi
pip uninstall -y dace
pip install -e ".[testing]"
pip install -e ".[testing,ml]"
curl -Os https://uploader.codecov.io/latest/linux/codecov
chmod +x codecov

Expand Down
57 changes: 57 additions & 0 deletions dace/autodiff/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""
DaCe Automatic Differentiation (AD) System.

This module provides reverse-mode automatic differentiation for DaCe programs,
enabling automatic computation of gradients for optimized numerical kernels.

Main Components
---------------
- **add_backward_pass**: Main entry point for adding backward pass to an SDFG
- **BackwardPassGenerator**: Core algorithm for generating backward passes
- **BackwardImplementation**: ABC for implementing operation-specific backward rules
- **BackwardContext**: Context information for backward pass generation
- **BackwardResult**: Result of backward pass generation with forward/backward SDFGs
- **AutoDiffException**: Base exception for autodiff errors

Key Features
------------
- Support for control flow (loops, conditionals)
- Data forwarding strategies (store vs recompute tradeoffs)
- Extensible backward implementations for library nodes
- Integration with PyTorch autograd
- Automatic memory management for intermediate values


"""

from .base_abc import BackwardImplementation, BackwardContext, BackwardResult, AutoDiffException
from .backward_pass_generator import BackwardPassGenerator
from .autodiff import add_backward_pass

try:
from .torch import make_backward_function
TORCH_INTEGRATION_AVAILABLE = True
except ImportError:
make_backward_function = None
TORCH_INTEGRATION_AVAILABLE = False

import sys
from . import library

__all__ = [
# Main API
"add_backward_pass",
# Core classes
"BackwardPassGenerator",
"BackwardContext",
"BackwardResult",
# Extension points
"BackwardImplementation",
# Exceptions
"AutoDiffException",
# Submodules
"library",
]

if TORCH_INTEGRATION_AVAILABLE:
__all__.append("make_backward_function")
96 changes: 96 additions & 0 deletions dace/autodiff/analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""
Analysis helpers for autodiff
"""
from typing import Dict, Set, Tuple, Optional
import collections

import networkx as nx

from dace.sdfg import SDFG, SDFGState, nodes, utils as sdfg_utils
from dace.transformation.passes import analysis
from dace.sdfg.state import FunctionCallRegion

AccessSets = Dict[SDFGState, Tuple[Set[str], Set[str]]]


def dependency_analysis(sdfg: SDFG) -> Dict[str, Set[str]]:
"""
Analyze read dependencies of arrays in the SDFG.

:param sdfg: SDFG to analyze
:returns: A dictionary mapping array names to a list of read dependencies.
"""

# FIXME can be made more efficient
dependencies = nx.DiGraph()
for sdfg_node in sdfg.nodes():
if isinstance(sdfg_node, SDFGState):
for node in sdfg_node.data_nodes():
for edge in sdfg_node.edge_bfs(node, reverse=True):
dependencies.add_edge(node.data, edge.data.data)
elif isinstance(sdfg_node, FunctionCallRegion):
for state in sdfg_node.nodes():
assert isinstance(state, SDFGState)
for node in state.data_nodes():
for edge in state.edge_bfs(node, reverse=True):
dependencies.add_edge(node.data, edge.data.data)

dependencies = nx.transitive_closure(dependencies)
result = {}
for array in dependencies:
result[array] = {nbr for nbr in dependencies.neighbors(array)}
return result


def inverse_reachability(sdfg: SDFG) -> Dict[SDFGState, Set[SDFGState]]:

reachability = analysis.StateReachability().apply_pass(sdfg, {})
inverse_reachability = collections.defaultdict(set)
# iterate over cfg_ids
for cfg_id in reachability.keys():
for pred, successors in reachability[cfg_id].items():
for successor in successors:
inverse_reachability[successor].add(pred)

return inverse_reachability


def is_previously_written(sdfg: SDFG,
state: SDFGState,
node: nodes.Node,
array_name: str,
access_sets: Optional[AccessSets] = None) -> bool:
"""
Determine whether the given array name was written before the current node.

:param sdfg: the sdfg containing the node
:param state: the state containing the node
:param node: the node to check
:param array_name: the array name to check
:returns: True if the array was written before the node, False otherwise.
"""

if access_sets is None:
access_sets = analysis.AccessSets().apply_pass(sdfg, {})

reachable = inverse_reachability(sdfg)

# check the current state
for subgraph in sdfg_utils.concurrent_subgraphs(state):
if node in subgraph.nodes():
# this is our current subgraph, check if it was written before in this subgraph
for edge in state.edge_bfs(node, reverse=True):
if edge.data.data == array_name:
return True
else:
# this is not our current subgraph, check the write states
_, write_set = subgraph.read_and_write_sets()
if array_name in write_set:
return True

# check other states
for predecessor in reachable[state]:
_, write_set = access_sets[predecessor]
if array_name in write_set:
return True
return False
Loading
Loading