[Track 1-B] Add where_backward operators#3582
Open
AdamPlatin123 wants to merge 14 commits into
Open
Conversation
…kward operator Triton kernel using atomic_add for gradient accumulation at pad boundaries. Simpler than reflection_pad (clamp indexing instead of modulo).
Uses @pointwise_dynamic pattern for automatic type promotion and non-contiguous support. Formula: grad * (1 if x > 0 else exp(x/alpha)) Changes: - src/flag_gems/ops/celu.py: Add celu_backward_kernel and celu_backward function - Register in ops/__init__.py and flag_gems/__init__.py - Add test_celu_backward.py with POINTWISE_SHAPES and FLOAT_DTYPES - Add benchmark/test_celu_backward.py
PReLU backward with support for both scalar and per-channel weights. Formulas: - grad_input = grad_output * (1 if x >= 0 else weight) - grad_weight = sum(grad_output * x for x < 0) per channel Uses Triton kernels with atomic_add for grad_weight accumulation. Changes: - src/flag_gems/ops/prelu_backward.py: New file with prelu_backward function - Register in ops/__init__.py and flag_gems/__init__.py - Add test_prelu_backward.py with scalar and per-channel tests - Add benchmark/test_prelu_backward.py
…kward operator 3D replication padding backward with clamp indexing and atomic_add accumulation. Formula: For each output position (d,h,w), map to input position using clamp: - id = clamp(od - pad_f, 0, D_in - 1) - ih = clamp(oh - pad_t, 0, H_in - 1) - iw = clamp(ow - pad_l, 0, W_in - 1) - atomic_add(grad_input[id, ih, iw], grad_output) Supports both 5D (N,C,D,H,W) and 4D (C,D,H,W) inputs. Changes: - src/flag_gems/ops/replication_pad3d_backward.py: New file (220 lines) - Register in ops/__init__.py and flag_gems/__init__.py - Add test_replication_pad3d_backward.py with multiple padding configs - Add benchmark/test_replication_pad3d_backward.py
…perator Uses @pointwise_dynamic pattern for automatic type promotion and non-contiguous support. Formula: grad_input = grad_output * (1 if |x| > lambd else 0) Softshrink forward: f(x) = x - lambd if x > lambd, x + lambd if x < -lambd, 0 otherwise. Changes: - src/flag_gems/ops/softshrink_backward.py: New file (@pointwise_dynamic, 18 lines) - Register in ops/__init__.py and flag_gems/__init__.py - Add test_softshrink_backward.py with POINTWISE_SHAPES and FLOAT_DTYPES - Add benchmark/test_softshrink_backward.py
…rink_backward operators - hardshrink: forward operator using @pointwise_dynamic - hardshrink_backward: backward operator using @pointwise_dynamic - Tests and benchmarks for both operators Formula: - forward: y = x if |x| > lambd else 0 - backward: grad_input = grad_output if |x| > lambd else 0
…operator - hardsigmoid_backward using @pointwise_dynamic - Test and benchmark files Formula: - hardsigmoid: y = clamp(x/6 + 0.5, 0, 1) - backward: grad_input = grad_output * (1/6) when -3 < x < 3, else 0
…maximum_backward operators
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Implement where_backward operators for Track 1-B (Operator Coverage).
Implementation
Both operators use @pointwise_dynamic decorator for automatic type promotion and non-contiguous tensor support.
Formulas
where.Self_backward:
where.Other_backward:
The gradients flow only to the branch that was selected in the forward pass.
Files Added
Registration
Testing
FlagGems Operator Development Competition