NNlib
- activations
  - tanh_fast
  - sigmoid
  - sigmoid_fast
- softmax!
- logsoftmax!
- conv
- conv! (feat: more coverage for NNlib functions #258)
- maxpool
- maxpool! (feat: more coverage for NNlib functions #258)
- meanpool
- meanpool! (feat: more coverage for NNlib functions #258)
- batched_mul
- batched_mul! (feat: more coverage for NNlib functions #258)
- padding
  - pad_constant
  - pad_circular
  - pad_repeat
  - pad_zeros
  - pad_reflect
  - pad_symmetric
- ∇conv_data (needed for ConvTranspose)
- ∇conv_data!
- ∇conv_filter
- ∇conv_filter!
- gather (partially addressed in feat: partial NNlib.gather support + better indexing support #252)
- gather! (partially addressed in feat: partial NNlib.gather support + better indexing support #252)
- scatter
- scatter!
- dot_product_attention
- pixel_shuffle
- batchnorm (has a special StableHLO impl)
  - this needs to wait for a corresponding adjoint on the EnzymeJAX end
- grid_sample
- grid_sample!
Some of these might not need dedicated handling at all, but we should check the generated IR for each one. Strikethrough marks operations that don't need specialized handling.
Feel free to add missing operations to the list.