[Track 1-B] Add where_backward operators #3582

Open
AdamPlatin123 wants to merge 14 commits into flagos-ai:master from AdamPlatin123:competition/where-backward

@AdamPlatin123

Summary

Implement where_backward operators for Track 1-B (Operator Coverage).

Implementation

Both operators use the @pointwise_dynamic decorator for automatic type promotion and non-contiguous tensor support.

Formulas

where.Self_backward:

  • where(cond, self, other) = self if cond else other
  • backward for self: grad = grad_output if cond else 0

where.Other_backward:

  • where(cond, self, other) = self if cond else other
  • backward for other: grad = 0 if cond else grad_output

The gradients flow only to the branch that was selected in the forward pass.
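As a quick sanity check of the two formulas, here is a plain-Python reference sketch. It is not the Triton implementation from this PR; the function names and the flat-list inputs are illustrative assumptions.

```python
# Reference sketch only: element-wise where_backward on flat Python lists.
# Function names are hypothetical and do not mirror the flag_gems API.

def where_self_backward(cond, grad_output):
    """Gradient for `self`: grad_output where cond is true, 0 elsewhere."""
    return [g if c else 0.0 for c, g in zip(cond, grad_output)]


def where_other_backward(cond, grad_output):
    """Gradient for `other`: 0 where cond is true, grad_output elsewhere."""
    return [0.0 if c else g for c, g in zip(cond, grad_output)]


cond = [True, False, True, False]
grad_output = [0.5, 1.5, -2.0, 3.0]

grad_self = where_self_backward(cond, grad_output)
grad_other = where_other_backward(cond, grad_output)

# Each element of grad_output lands in exactly one of the two gradients.
recombined = [a + b for a, b in zip(grad_self, grad_other)]
```

Because every element is routed to exactly one branch, adding the two gradients back together reproduces grad_output, which makes a cheap invariant for tests.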

Files Added

  • src/flag_gems/ops/where_backward.py - Two operator implementations (27 lines)
  • tests/test_where_backward.py - Accuracy tests for both variants
  • benchmark/test_where_backward.py - Performance benchmarks for both variants

Registration

  • Registered in ops/__init__.py
  • Registered in flag_gems/__init__.py

Testing

  • Pointwise shapes coverage
  • All floating-point dtypes (float32/bfloat16/float16)
  • Both Self_backward and Other_backward variants tested

FlagGems Operator Development Competition

…kward operator

Triton kernel using atomic_add for gradient accumulation at pad boundaries.
Simpler than reflection_pad (clamp indexing instead of modulo).
Uses @pointwise_dynamic pattern for automatic type promotion and non-contiguous support.

Formula: grad * (1 if x > 0 else exp(x/alpha))
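The formula can be checked against the standard CELU definition with a plain-Python sketch. This is not the kernel in celu.py; the function names, the default alpha, and the list-based inputs are assumptions made for illustration.

```python
import math

# Reference sketch only; names and defaults are hypothetical.

def celu(x, alpha=1.0):
    """Standard CELU forward: max(0, x) + min(0, alpha * (exp(x / alpha) - 1))."""
    return max(0.0, x) + min(0.0, alpha * (math.exp(x / alpha) - 1.0))


def celu_backward(grad_output, x, alpha=1.0):
    """grad * (1 if x > 0 else exp(x / alpha)), applied element-wise."""
    return [
        g * (1.0 if xi > 0 else math.exp(xi / alpha))
        for g, xi in zip(grad_output, x)
    ]


# Compare the analytic gradient with a central finite difference at one point.
x0, alpha, eps = -0.7, 2.0, 1e-6
numeric = (celu(x0 + eps, alpha) - celu(x0 - eps, alpha)) / (2 * eps)
analytic = celu_backward([1.0], [x0], alpha)[0]
```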

Changes:

- src/flag_gems/ops/celu.py: Add celu_backward_kernel and celu_backward function

- Register in ops/__init__.py and flag_gems/__init__.py

- Add test_celu_backward.py with POINTWISE_SHAPES and FLOAT_DTYPES

- Add benchmark/test_celu_backward.py
PReLU backward with support for both scalar and per-channel weights.

Formulas:

- grad_input = grad_output * (1 if x >= 0 else weight)

- grad_weight = sum(grad_output * x for x < 0) per channel

Uses Triton kernels with atomic_add for grad_weight accumulation.
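A plain-Python sketch of the per-channel case, following the two formulas above. It is not the code in prelu_backward.py: the (N, C) nested-list layout and the function name are assumptions, and the `+=` on grad_weight stands in for the atomic_add the Triton kernel uses.

```python
# Reference sketch only; layout (N rows, C channels) and names are hypothetical.

def prelu_backward(grad_output, x, weight):
    """Return (grad_input, grad_weight) for per-channel PReLU."""
    num_channels = len(weight)
    grad_input = []
    grad_weight = [0.0] * num_channels
    for g_row, x_row in zip(grad_output, x):
        gi_row = []
        for c in range(num_channels):
            g, xi = g_row[c], x_row[c]
            if xi >= 0:
                # Positive side: slope 1, no contribution to grad_weight.
                gi_row.append(g)
            else:
                # Negative side: slope is the channel weight.
                gi_row.append(g * weight[c])
                # Sequential stand-in for atomic_add on grad_weight[c].
                grad_weight[c] += g * xi
        grad_input.append(gi_row)
    return grad_input, grad_weight


x = [[1.0, -2.0], [-3.0, 4.0]]
grad_output = [[1.0, 1.0], [2.0, 2.0]]
weight = [0.5, 0.25]
grad_input, grad_weight = prelu_backward(grad_output, x, weight)
```

The scalar-weight case reduces to the same loop with a single accumulator shared by all channels.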

Changes:

- src/flag_gems/ops/prelu_backward.py: New file with prelu_backward function

- Register in ops/__init__.py and flag_gems/__init__.py

- Add test_prelu_backward.py with scalar and per-channel tests

- Add benchmark/test_prelu_backward.py
…kward operator

3D replication padding backward with clamp indexing and atomic_add accumulation.

Formula: For each output position (od, oh, ow), map to the input position using clamp:

- id = clamp(od - pad_f, 0, D_in - 1)

- ih = clamp(oh - pad_t, 0, H_in - 1)

- iw = clamp(ow - pad_l, 0, W_in - 1)

- atomic_add(grad_input[id, ih, iw], grad_output)

Supports both 5D (N,C,D,H,W) and 4D (C,D,H,W) inputs.
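The clamp-and-accumulate mapping can be sketched in plain Python for a single (D, H, W) volume. This is not the Triton kernel from replication_pad3d_backward.py; the function signature and nested-list layout are assumptions, and the `+=` plays the role of atomic_add.

```python
# Reference sketch only; signature and data layout are hypothetical.

def clamp(v, lo, hi):
    return max(lo, min(v, hi))


def replication_pad3d_backward(grad_output, d_in, h_in, w_in, pad_f, pad_t, pad_l):
    """Accumulate grad_output (D_out, H_out, W_out) into a (d_in, h_in, w_in) grad_input.

    Only the front/top/left pads are needed for the index mapping; the
    back/bottom/right pads are implied by the shape of grad_output.
    """
    grad_input = [[[0.0] * w_in for _ in range(h_in)] for _ in range(d_in)]
    for od, plane in enumerate(grad_output):
        i_d = clamp(od - pad_f, 0, d_in - 1)
        for oh, row in enumerate(plane):
            i_h = clamp(oh - pad_t, 0, h_in - 1)
            for ow, g in enumerate(row):
                i_w = clamp(ow - pad_l, 0, w_in - 1)
                # Several output positions can map to one border element,
                # which is why the parallel kernel needs atomic_add here.
                grad_input[i_d][i_h][i_w] += g
    return grad_input


# Input of width 2 padded by 1 on the left and right gives width 4.
grad_w = replication_pad3d_backward([[[1.0, 1.0, 1.0, 1.0]]], 1, 1, 2, 0, 0, 1)

# Input of depth 2 padded by 1 at the front and back gives depth 4.
grad_d = replication_pad3d_backward(
    [[[1.0]], [[2.0]], [[3.0]], [[4.0]]], 2, 1, 1, 1, 0, 0
)
```

In both examples the total of grad_input equals the total of grad_output, since every output gradient is added to exactly one input element.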

Changes:

- src/flag_gems/ops/replication_pad3d_backward.py: New file (220 lines)

- Register in ops/__init__.py and flag_gems/__init__.py

- Add test_replication_pad3d_backward.py with multiple padding configs

- Add benchmark/test_replication_pad3d_backward.py
…perator

Uses @pointwise_dynamic pattern for automatic type promotion and non-contiguous support.

Formula: grad_input = grad_output * (1 if |x| > lambd else 0)

Softshrink forward: f(x) = x - lambd if x > lambd, x + lambd if x < -lambd, 0 otherwise.
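The forward and backward definitions above translate directly into a scalar reference sketch. It is written in plain Python rather than with @pointwise_dynamic, and the function names and default lambd are assumptions for illustration.

```python
# Reference sketch only; names and the default lambd are hypothetical.

def softshrink(x, lambd=0.5):
    """x - lambd if x > lambd, x + lambd if x < -lambd, 0 otherwise."""
    if x > lambd:
        return x - lambd
    if x < -lambd:
        return x + lambd
    return 0.0


def softshrink_backward(grad_output, x, lambd=0.5):
    """The forward has slope 1 outside [-lambd, lambd] and slope 0 inside it."""
    return grad_output if abs(x) > lambd else 0.0


samples = [-2.0, -0.5, 0.25, 0.5, 2.0]
forward = [softshrink(x) for x in samples]
backward = [softshrink_backward(3.0, x) for x in samples]
```

Note that the boundary points x = ±lambd fall in the zero-gradient region because the comparison is strict.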

Changes:

- src/flag_gems/ops/softshrink_backward.py: New file (@pointwise_dynamic, 18 lines)

- Register in ops/__init__.py and flag_gems/__init__.py

- Add test_softshrink_backward.py with POINTWISE_SHAPES and FLOAT_DTYPES

- Add benchmark/test_softshrink_backward.py
…rink_backward operators

- hardshrink: forward operator using @pointwise_dynamic
- hardshrink_backward: backward operator using @pointwise_dynamic
- Tests and benchmarks for both operators

Formula:
- forward: y = x if |x| > lambd else 0
- backward: grad_input = grad_output if |x| > lambd else 0
…operator

- hardsigmoid_backward using @pointwise_dynamic
- Test and benchmark files

Formula:
- hardsigmoid: y = clamp(x/6 + 0.5, 0, 1)
- backward: grad_input = grad_output * (1/6) when -3 < x < 3, else 0
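A scalar reference sketch of the two hardsigmoid formulas, in plain Python. This is not the @pointwise_dynamic implementation from the commit, and the function names are assumptions.

```python
# Reference sketch only; function names are hypothetical.

def hardsigmoid(x):
    """y = clamp(x / 6 + 0.5, 0, 1)."""
    return min(max(x / 6.0 + 0.5, 0.0), 1.0)


def hardsigmoid_backward(grad_output, x):
    """Slope is 1/6 on the open interval (-3, 3) and 0 in the saturated regions."""
    return grad_output / 6.0 if -3.0 < x < 3.0 else 0.0


forward = [hardsigmoid(x) for x in (-3.0, 0.0, 3.0)]
backward = [hardsigmoid_backward(6.0, x) for x in (-4.0, 0.0, 3.0)]
```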