Skip to content

[Track 1-B] Add relu6_backward operator#3577

Open
AdamPlatin123 wants to merge 11 commits into
flagos-ai:masterfrom
AdamPlatin123:competition/relu6-backward
Open

[Track 1-B] Add relu6_backward operator#3577
AdamPlatin123 wants to merge 11 commits into
flagos-ai:masterfrom
AdamPlatin123:competition/relu6-backward

Conversation

@AdamPlatin123
Copy link
Copy Markdown

Summary

Implement relu6_backward operator for Track 1-B (Operator Coverage).

Implementation

Uses @pointwise_dynamic decorator for automatic type promotion and non-contiguous tensor support.

Formula

  • relu6(x) = min(max(x, 0), 6)
  • backward: grad_input = grad_output if 0 < x < 6, else 0

The gradient only flows through values in the active range (0, 6).

Files Added

  • src/flag_gems/ops/relu6_backward.py - Operator implementation (16 lines)
  • tests/test_relu6_backward.py - Accuracy tests
  • benchmark/test_relu6_backward.py - Performance benchmark

Registration

  • Registered in ops/init.py
  • Registered in flag_gems/init.py

Testing

  • Pointwise shapes coverage
  • All floating point dtypes (float32/bfloat16/float16)

FlagGems Operator Development Competition

…kward operator

Triton kernel using atomic_add for gradient accumulation at pad boundaries.
Simpler than reflection_pad (clamp indexing instead of modulo).
Uses @pointwise_dynamic pattern for automatic type promotion and non-contiguous support.

Formula: grad * (1 if x > 0 else exp(x/alpha))

Changes:

- src/flag_gems/ops/celu.py: Add celu_backward_kernel and celu_backward function

- Register in ops/__init__.py and flag_gems/__init__.py

- Add test_celu_backward.py with POINTWISE_SHAPES and FLOAT_DTYPES

- Add benchmark/test_celu_backward.py
PReLU backward with support for both scalar and per-channel weights.

Formulas:

- grad_input = grad_output * (1 if x >= 0 else weight)

- grad_weight = sum(grad_output * x for x < 0) per channel

Uses Triton kernels with atomic_add for grad_weight accumulation.

Changes:

- src/flag_gems/ops/prelu_backward.py: New file with prelu_backward function

- Register in ops/__init__.py and flag_gems/__init__.py

- Add test_prelu_backward.py with scalar and per-channel tests

- Add benchmark/test_prelu_backward.py
…kward operator

3D replication padding backward with clamp indexing and atomic_add accumulation.

Formula: For each output position (d,h,w), map to input position using clamp:

- id = clamp(od - pad_f, 0, D_in - 1)

- ih = clamp(oh - pad_t, 0, H_in - 1)

- iw = clamp(ow - pad_l, 0, W_in - 1)

- atomic_add(grad_input[id, ih, iw], grad_output)

Supports both 5D (N,C,D,H,W) and 4D (C,D,H,W) inputs.

Changes:

- src/flag_gems/ops/replication_pad3d_backward.py: New file (220 lines)

- Register in ops/__init__.py and flag_gems/__init__.py

- Add test_replication_pad3d_backward.py with multiple padding configs

- Add benchmark/test_replication_pad3d_backward.py
…perator

Uses @pointwise_dynamic pattern for automatic type promotion and non-contiguous support.

Formula: grad_input = grad_output * (1 if |x| > lambd else 0)

Softshrink forward: f(x) = x - lambd if x > lambd, x + lambd if x < -lambd, 0 otherwise.

Changes:

- src/flag_gems/ops/softshrink_backward.py: New file (@pointwise_dynamic, 18 lines)

- Register in ops/__init__.py and flag_gems/__init__.py

- Add test_softshrink_backward.py with POINTWISE_SHAPES and FLOAT_DTYPES

- Add benchmark/test_softshrink_backward.py
…rink_backward operators

- hardshrink: forward operator using @pointwise_dynamic
- hardshrink_backward: backward operator using @pointwise_dynamic
- Tests and benchmarks for both operators

Formula:
- forward: y = x if |x| > lambd else 0
- backward: grad_input = grad_output if |x| > lambd else 0
…operator

- hardsigmoid_backward using @pointwise_dynamic
- Test and benchmark files

Formula:
- hardsigmoid: y = clamp(x/6 + 0.5, 0, 1)
- backward: grad_input = grad_output * (1/6) when -3 < x < 3, else 0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant