
Autograd completeness: pow + log + conv/pool backward formulas (#617) #618

Merged
michalharakal merged 5 commits into develop from feature/autograd-completeness on May 18, 2026

@michalharakal
Contributor

Closes #617.

Fills out DefaultGradientTape so every op a real CNN training loop touches has a backward formula. The gradient tape itself already existed (~900 LOC, 40+ formulas) — this PR adds the missing pieces. Five commits, four trackable tiers.

Summary

  • Tier A (95cfd2d) — pow / powScalar op across TensorOps, CPU backend, HLO emission (stablehlo.power), plus a PowSpecializationPass that rewrites pow(x, 2) → multiply(x, x).
  • Tier B (f9747b0) — log / log2 / log10 op family (prerequisite for the full pow backward w.r.t. exponent).
  • Tier C.1 (6ca1bfe) — backward formulas for pow, powScalar, log, log2, log10 + dispatch arms.
  • Tier C.2/3 (b256394) — backward formulas for conv1d, conv2d, conv3d, maxPool2d, avgPool2d, upsample2d (Nearest), split. First-cut direct CPU loops — correctness over speed; perf path is a follow-up.
  • Tier D (412127e) — end-to-end CnnTrainingStepTest runs a conv→ReLU→maxPool→reshape→matmul network through one SGD step; loss decreases, every parameter gets a non-null grad.

Notable architecture decision

split needed special handling in recordTrace (one BackwardOp per output chunk, each scattering its upstream into a zero-filled input grad via scatterAlongDim). The standard BackwardOp(output=...) shape can't carry N upstream gradients, and reworking the tape framework would have been disproportionate — a localised special-case keeps the change small.
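The per-chunk scatter described above can be illustrated on 1-D arrays. This is only a sketch: `scatterChunk` and `splitBackward` are hypothetical names standing in for the scatterAlongDim idea, and plain `FloatArray`s replace the project's tensor type.

```kotlin
// Hypothetical 1-D stand-in for scatterAlongDim: place `chunkGrad` at `offset`
// inside a zero-filled gradient that has the shape of the original input.
fun scatterChunk(chunkGrad: FloatArray, offset: Int, inputSize: Int): FloatArray {
    val grad = FloatArray(inputSize) // zero-filled by default
    chunkGrad.copyInto(grad, destinationOffset = offset)
    return grad
}

// One backward per chunk; the contributions are summed, the way a tape
// accumulates several gradients that flow into the same input.
fun splitBackward(chunkGrads: List<FloatArray>, inputSize: Int): FloatArray {
    val total = FloatArray(inputSize)
    var offset = 0
    for (g in chunkGrads) {
        val scattered = scatterChunk(g, offset, inputSize)
        for (i in total.indices) total[i] += scattered[i]
        offset += g.size
    }
    return total
}
```

Because each scattered gradient is zero outside its own chunk, summing them gives the same result as concatenating the chunk gradients along the split dimension.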

Test plan

  • Per-op finite-difference parity for every new backward (PowLogBackwardTest 6 tests, ConvPoolBackwardTest 8 tests). FP32 tol 1e-2 for elementwise, 3e-2 for conv.
  • End-to-end CNN training step (CnnTrainingStepTest) — loss decreases, all 4 trainable params get grads.
  • Regression: AutogradBasicTest, SkainetScopeTest still green.
  • Cross-module sweep green: :skainet-lang:skainet-lang-core:jvmTest, :skainet-backends:skainet-backend-cpu:jvmTest, :skainet-compile:skainet-compile-opt:jvmTest, :skainet-compile:skainet-compile-hlo:jvmTest, :skainet-compile:skainet-compile-dag:jvmTest.
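For readers unfamiliar with finite-difference parity, a minimal sketch of the check follows. It assumes a scalar-valued function over a flat `DoubleArray`; the helper names `numericGrad` and `gradsMatch` are hypothetical, and the actual tests run in FP32 with the tolerances listed above.

```kotlin
import kotlin.math.abs

// Central finite-difference estimate of d f / d x[i] for a scalar-valued f.
fun numericGrad(f: (DoubleArray) -> Double, x: DoubleArray, eps: Double = 1e-3): DoubleArray =
    DoubleArray(x.size) { i ->
        val plus = x.copyOf().also { it[i] += eps }
        val minus = x.copyOf().also { it[i] -= eps }
        (f(plus) - f(minus)) / (2 * eps)
    }

// True when every analytic entry lies within `tol` of the numeric estimate.
fun gradsMatch(analytic: DoubleArray, numeric: DoubleArray, tol: Double): Boolean =
    analytic.size == numeric.size &&
        analytic.indices.all { abs(analytic[it] - numeric[it]) <= tol }
```

A backward formula passes when `gradsMatch(analyticGrad, numericGrad(loss, x), tol)` holds; conv accumulates more rounding error per output than an elementwise op, which is consistent with the looser conv tolerance.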

Out of scope (follow-ups)

  • Higher-order gradients.
  • upsample2d bilinear backward (forward doesn't support bilinear yet).
  • Conv backward perf path — direct loops are O(N·C·K²·H·W); a tiled / SPI-routed version is a separate ticket.
  • Recording argmax during maxPool2d forward (the backward recomputes it — fine for correctness, slower than necessary).
  • Native FFM pow / log specialisations (waits on the native FFM provider).
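The argmax-recompute approach mentioned in the maxPool2d item can be sketched in one dimension. This is an illustration under simplifying assumptions (no padding, 1-D windows, a hypothetical `maxPool1dBackward` helper), not the project's maxPool2dGrad.

```kotlin
// 1-D max-pool backward that recomputes the argmax of each window (no padding).
// Ties go to the first maximum encountered, because only a strictly greater
// value replaces the current best index.
fun maxPool1dBackward(input: FloatArray, upstream: FloatArray, kernel: Int, stride: Int): FloatArray {
    val dInput = FloatArray(input.size)
    for (o in upstream.indices) {
        val start = o * stride
        var best = start
        for (i in start + 1 until start + kernel) {
            if (input[i] > input[best]) best = i
        }
        dInput[best] += upstream[o] // route the whole upstream value to the winner
    }
    return dInput
}
```

Storing `best` during the forward pass would remove the inner scan from the backward, which is the optimisation the follow-up item refers to.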

🤖 Generated with Claude Code

michalharakal and others added 5 commits May 18, 2026 07:49
Adds element-wise `pow(a, b)` and `powScalar(a, n)` to TensorOps,
emits `stablehlo.power` from the HLO converter, and introduces
PowSpecializationPass that rewrites `pow(x, 2)` to `multiply(x, x)`
in the graph optimization pipeline (so the matmul / SIMD elementwise
kernels do the work, not a real `pow` per element).

Surfaces touched:
- TensorOps interface — `@Diff` annotated `pow(a, b)` and
  `powScalar(a, n)`.
- VoidTensorOps stubs.
- DefaultCpuOps scalar impl with two arms: small-integer exponents
  (|n| ≤ 16) use repeated-multiply (exact); everything else routes
  through kotlin.math.pow.
- PowOperation data class in TensorOperations.kt — same form supports
  binary (two tensor inputs) and scalar (single input + parameters
  ["scalar_exponent"]) shapes.
- RecordingTensorOpsDecorator records both into PowOperation with the
  scalar value preserved in parameters for backward recovery.
- Tensor.pow(Number) / Tensor.pow(Tensor) extensions (no operator
  form — Kotlin has no `**`).
- BasicMathConverter emits `stablehlo.power` for the binary form.
- DefaultGradientTape has powBackward / powScalarBackward stub
  overrides returning null (real formulas land in Tier C alongside
  conv/pool backward).
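
A rough sketch of the two-arm scalar implementation described in the DefaultCpuOps item. The function names are hypothetical, and the reciprocal handling of negative integer exponents is an assumption; the source only states the |n| ≤ 16 threshold and the kotlin.math.pow fallback.

```kotlin
import kotlin.math.abs
import kotlin.math.pow

// Two-arm scalar pow: small integer exponents use repeated multiplication,
// everything else goes through kotlin.math.pow.
fun powScalarElement(x: Float, n: Float): Float {
    val k = n.toInt()
    val isSmallInteger = k.toFloat() == n && k in -16..16
    if (!isSmallInteger) return x.pow(n)  // general arm
    var acc = 1f
    repeat(abs(k)) { acc *= x }           // repeated multiply
    return if (k >= 0) acc else 1f / acc  // assumption: negative n via reciprocal
}

// Element-wise application over a flat array.
fun powScalar(a: FloatArray, n: Float): FloatArray =
    FloatArray(a.size) { powScalarElement(a[it], n) }
```

For example, `powScalar(floatArrayOf(1f, 2f, 3f), 3f)` takes the repeated-multiply arm, while an exponent such as `0.5f` falls through to `pow`.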

PowSpecializationPass currently specialises only n=2 (the most
common case — RMSNorm/MSE/GELU all use squared); n=3+ is a follow-up.
Registered in createDefault / createAggressive / createLLM
pipelines after DTypeConstraintResolutionPass and before fusion so
the multiply form propagates to fusion.
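
The rewrite can be pictured on a toy expression tree. The `Expr` types and `specializePow` below are hypothetical and are not the project's graph IR or pass API; the sketch only shows the n=2 rule applied bottom-up.

```kotlin
// Toy expression tree (hypothetical types, for illustration only).
sealed interface Expr
data class Input(val name: String) : Expr
data class Const(val value: Double) : Expr
data class Pow(val base: Expr, val exponent: Expr) : Expr
data class Mul(val lhs: Expr, val rhs: Expr) : Expr

// Bottom-up rewrite: pow(x, 2) becomes multiply(x, x);
// every other node is rebuilt unchanged.
fun specializePow(e: Expr): Expr = when (e) {
    is Pow -> {
        val base = specializePow(e.base)
        val exponent = specializePow(e.exponent)
        if (exponent == Const(2.0)) Mul(base, base) else Pow(base, exponent)
    }
    is Mul -> Mul(specializePow(e.lhs), specializePow(e.rhs))
    is Input, is Const -> e
}
```

Running the rewrite before fusion means later passes only ever see the multiply form for squares, which matches the ordering rationale given above.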

JVM Vector-API specialisation deliberately skipped — sqrt / exp /
abs are all scalar-only today; matches existing pattern.

Tier A scope per the plan: 7 forward-parity tests + 4 specialisation-
pass tests, all green locally. No regression on engine bench
scenarios (none touch pow).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Element-wise natural / base-2 / base-10 logarithms on TensorOps,
with scalar DefaultCpuOps impls routing to kotlin.math.ln/log2/log10
per element. Same dtype guard as sqrt (FP16/FP32 only).

Surfaces touched:
- TensorOps — `@Diff` annotated `log`, `log2`, `log10`.
- VoidTensorOps stubs.
- LogOperation / Log2Operation / Log10Operation data classes in
  TensorOperations.kt (single-input, shape-preserving).
- DefaultCpuOps scalar implementations.
- RecordingTensorOpsDecorator pass-through overrides (matches sqrt/abs
  pattern — KSP-generated wrapper handles tape recording).
- Tensor.log() / .log2() / .log10() extension functions.
- DefaultGradientTape logBackward / log2Backward / log10Backward
  stubs returning null. Real formulas land in Tier C:
    log:   grad_a = upstream / a
    log2:  grad_a = upstream / (a * ln 2)
    log10: grad_a = upstream / (a * ln 10)

HLO emission: `log` is auto-wired via the existing UnaryMathConverter
("log" -> "stablehlo.log" was already in the opMap). `log2` and
`log10` deliberately NOT emitted — StableHLO has no native ops for
either, so a graph using them fails HLO compilation with a clean
"Unsupported" error. Lowering as `log(x) / ln(base)` is a small
follow-up.
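
The proposed lowering is the change-of-base identity. A scalar sketch, with hypothetical function names; in an HLO graph the same thing would presumably be a `stablehlo.log` followed by a division by a constant.

```kotlin
import kotlin.math.ln

// Change-of-base lowering: log_b(x) = ln(x) / ln(b).
fun log2ViaLn(x: Double): Double = ln(x) / ln(2.0)
fun log10ViaLn(x: Double): Double = ln(x) / ln(10.0)
```

The edge cases carry over from `ln`: zero maps to negative infinity and negative inputs map to NaN, matching the forward behaviour tested in this commit.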

JVM Vector-API specialisation skipped — consistent with sqrt/exp/abs
which are also scalar-only in the JVM backend today.

Tests: 7 forward parity tests covering canonical values, NaN/Inf
edge cases (log of negative -> NaN, log of zero -> -Inf), three-way
consistency (log2 = log/ln(2), log10 = log/ln(10)), and dtype-guard
rejection of Int32. All green locally.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

powBackward, powScalarBackward, logBackward, log2Backward, log10Backward
replace the null stubs from Tiers A/B with real formulas. powScalar reads
KSP's "n" string attribute and falls back to the decorator's
"scalar_exponent" Number so both recording paths work.

PowLogBackwardTest verifies each formula against central finite-difference
(tol 1e-2 for FP32). Conv/pool/split backward still stubbed — next half
of Tier C.
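
For reference, the standard calculus results these backward functions presumably implement are listed below. The symbols are assumptions for illustration: g is the upstream gradient and L the scalar loss. The exponent gradient of pow contains ln a, which is why the log family was a prerequisite; it is only defined for a > 0.

```latex
\begin{aligned}
y = a^{b}: \quad & \frac{\partial L}{\partial a} = g \cdot b \cdot a^{b-1}, \qquad
                   \frac{\partial L}{\partial b} = g \cdot a^{b} \ln a \\
y = a^{n}\ (\text{scalar } n): \quad & \frac{\partial L}{\partial a} = g \cdot n \cdot a^{n-1} \\
y = \ln a: \quad & \frac{\partial L}{\partial a} = \frac{g}{a} \\
y = \log_{2} a: \quad & \frac{\partial L}{\partial a} = \frac{g}{a \ln 2} \\
y = \log_{10} a: \quad & \frac{\partial L}{\partial a} = \frac{g}{a \ln 10}
\end{aligned}
```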

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Replaces the null stubs for conv1d/2d/3d, maxPool2d, avgPool2d,
upsample2d and split with first-cut direct CPU loops:

- conv{1,2,3}dGrads — closed-form dInput / dWeight / dBias from the
  forward windowing rule (ih = oh*sH - pH + kh*dH …). Groups,
  stride, padding, dilation handled.
- maxPool2dGrad — recomputes argmax per window and routes upstream
  there. Ties resolved to first encountered (matches forward order).
- avgPool2dGrad — distributes upstream across the window; divisor
  follows forward countIncludePad rule.
- upsample2dGrad — nearest-only; each input pixel sums the upstream
  values of the output block it was replicated into. Bilinear mode
  raises an error (forward doesn't support it).
- avgPool2d dispatch arm in buildBackwardFromTrace (was missing).
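
To make the windowing rule concrete, here is a 1-D direct-loop sketch for a single sample with groups = 1. The `Conv1dGrads` container, the function name, and the nested-array layout are hypothetical; the project's conv{1,2,3}dGrads signatures are not shown in this PR text.

```kotlin
// Hypothetical result container for the three conv gradients.
class Conv1dGrads(
    val dInput: Array<FloatArray>,
    val dWeight: Array<Array<FloatArray>>,
    val dBias: FloatArray
)

fun conv1dBackward(
    input: Array<FloatArray>,          // [cIn][length]
    weight: Array<Array<FloatArray>>,  // [cOut][cIn][kernel]
    upstream: Array<FloatArray>,       // [cOut][outLength]
    stride: Int = 1,
    padding: Int = 0,
    dilation: Int = 1
): Conv1dGrads {
    val cIn = input.size
    val length = input[0].size
    val dInput = Array(cIn) { FloatArray(length) }
    val dWeight = Array(weight.size) { co -> Array(cIn) { ci -> FloatArray(weight[co][ci].size) } }
    val dBias = FloatArray(weight.size)
    for (co in weight.indices) for (ol in upstream[co].indices) {
        val g = upstream[co][ol]
        dBias[co] += g                                    // bias sees every output position
        for (ci in 0 until cIn) for (k in weight[co][ci].indices) {
            val il = ol * stride - padding + k * dilation // forward windowing rule
            if (il < 0 || il >= length) continue          // padded position: no gradient
            dInput[ci][il] += g * weight[co][ci][k]
            dWeight[co][ci][k] += g * input[ci][il]
        }
    }
    return Conv1dGrads(dInput, dWeight, dBias)
}
```

The loop visits exactly the (output, kernel) pairs the forward pass would, so each forward multiply-accumulate contributes one term to dInput and one to dWeight.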

split needs N backwards (one per chunk) because BackwardOp carries one
output. recordTrace now special-cases "split" → registerSplitBackwards,
each chunk's backward scatters upstream into a zeros input grad via
scatterAlongDim; tape accumulation then sums the per-chunk grads, which
is equivalent to concatenating them.

ConvPoolBackwardTest exercises every new formula against central
finite-difference (tol 3e-2 for FP32 conv noise).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

CnnTrainingStepTest builds a tiny conv2d → ReLU → maxPool2d → reshape →
matmul + bias network, records the forward pass, runs the tape backward,
applies one SGD step, and asserts loss doesn't increase. Confirms the
full forward+backward+optimiser path composes correctly with the new
backward formulas from Tier C — every trainable parameter (convW, convB,
linW, linB) receives a non-null gradient.

Closes the autograd-completeness work for #617. Cross-module regression
sweep (lang-core, cpu backend, compile-opt, compile-hlo, compile-dag)
green.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@michalharakal michalharakal merged commit 61c7fea into develop May 18, 2026
11 checks passed
@github-actions

📖 Documentation Preview

The documentation has been built successfully for this PR.

Generated Files:

  • Operator documentation: docs/modules/operators/_generated_/
  • JSON schema output: operators.json

Artifacts:

  • Download the documentation-preview-618 artifact to view the complete documentation locally.

This comment will be updated automatically when the PR is updated.


Development

Successfully merging this pull request may close these issues.

Add tensor pow op + StableHLO power parity + autograd completeness (conv/pool backward, log family)
