Skip to content

Commit 95cfd2d

Browse files
michalharakalclaude
andcommitted
Add pow / powScalar op + PowSpecializationPass (Tier A of #617)
Adds element-wise `pow(a, b)` and `powScalar(a, n)` to TensorOps, emits `stablehlo.power` from the HLO converter, and introduces PowSpecializationPass that rewrites `pow(x, 2)` to `multiply(x, x)` in the graph optimization pipeline (so the matmul / SIMD elementwise kernels do the work, not a real `pow` per element). Surfaces touched: - TensorOps interface — `@Diff` annotated `pow(a, b)` and `powScalar(a, n)`. - VoidTensorOps stubs. - DefaultCpuOps scalar impl with two arms: small-integer exponents (|n| ≤ 16) use repeated-multiply (exact); everything else routes through kotlin.math.pow. - PowOperation data class in TensorOperations.kt — same form supports binary (two tensor inputs) and scalar (single input + parameters ["scalar_exponent"]) shapes. - RecordingTensorOpsDecorator records both into PowOperation with the scalar value preserved in parameters for backward recovery. - Tensor.pow(Number) / Tensor.pow(Tensor) extensions (no operator form — Kotlin has no `**`). - BasicMathConverter emits `stablehlo.power` for the binary form. - DefaultGradientTape has powBackward / powScalarBackward stub overrides returning null (real formulas land in Tier C alongside conv/pool backward). PowSpecializationPass currently specialises only n=2 (the most common case — RMSNorm/MSE/GELU all use squared); n=3+ is a follow-up. Registered in createDefault / createAggressive / createLLM pipelines after DTypeConstraintResolutionPass and before fusion so the multiply form propagates to fusion. JVM Vector-API specialisation deliberately skipped — sqrt / exp / abs are all scalar-only today; matches existing pattern. Tier A scope per the plan: 7 forward-parity tests + 4 specialisation- pass tests, all green locally. No regression on engine bench scenarios (none touch pow). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent a1fc274 commit 95cfd2d

13 files changed

Lines changed: 489 additions & 1 deletion

File tree

skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import sk.ainet.lang.tensor.data.FloatArrayTensorData
1212
import sk.ainet.lang.tensor.data.TensorDataFactory
1313
import sk.ainet.lang.tensor.ops.UpsampleMode
1414
import sk.ainet.lang.types.FP32
15+
import kotlin.math.pow
1516
import kotlin.math.sqrt
1617

1718
@Backend(id = "cpu", displayName = "CPU")
@@ -2123,6 +2124,66 @@ public open class DefaultCpuOpsBase(protected val dataFactory: TensorDataFactory
21232124
return newTensor(outData, tensor.dtype, tensor)
21242125
}
21252126

2127+
/**
2128+
* Element-wise power: `c[i] = a[i] ^ b[i]`. Integer-valued exponents
2129+
* use repeated multiply for stability; everything else routes through
2130+
* `kotlin.math.pow`. Shape contract: shapes must match exactly (no
2131+
* broadcasting yet — caller's responsibility).
2132+
*/
2133+
override fun <T : DType, V> pow(a: Tensor<T, V>, b: Tensor<T, V>): Tensor<T, V> {
2134+
require(
2135+
a.dtype == sk.ainet.lang.types.FP32::class ||
2136+
a.dtype == sk.ainet.lang.types.FP16::class
2137+
) { "pow supports only FP16/FP32, got ${a.dtype}" }
2138+
require(a.shape == b.shape) { "pow requires matching shapes; got ${a.shape} and ${b.shape}" }
2139+
val outData = dataFactory.init<T, V>(a.shape, a.dtype) { idx ->
2140+
val av = a.data.get(*idx) as Float
2141+
val bv = b.data.get(*idx) as Float
2142+
@Suppress("UNCHECKED_CAST")
2143+
scalarPow(av, bv) as V
2144+
}
2145+
return newTensor(outData, a.dtype, a)
2146+
}
2147+
2148+
/**
2149+
* Element-wise scalar power: `c[i] = a[i] ^ n`. Small-integer
2150+
* exponents (|n| <= 16) use repeated multiply for exactness; all
2151+
* other values route through `kotlin.math.pow`.
2152+
*/
2153+
override fun <T : DType, V> powScalar(a: Tensor<T, V>, n: Number): Tensor<T, V> {
2154+
require(
2155+
a.dtype == sk.ainet.lang.types.FP32::class ||
2156+
a.dtype == sk.ainet.lang.types.FP16::class
2157+
) { "powScalar supports only FP16/FP32, got ${a.dtype}" }
2158+
val nFloat = n.toFloat()
2159+
val nInt = n.toInt()
2160+
val isSmallInt = nFloat == nInt.toFloat() && kotlin.math.abs(nInt) <= 16
2161+
val outData = dataFactory.init<T, V>(a.shape, a.dtype) { idx ->
2162+
val av = a.data.get(*idx) as Float
2163+
@Suppress("UNCHECKED_CAST")
2164+
(if (isSmallInt) integerPow(av, nInt) else scalarPow(av, nFloat)) as V
2165+
}
2166+
return newTensor(outData, a.dtype, a)
2167+
}
2168+
2169+
/** Repeated-multiply for small integer exponents. Handles n < 0 via reciprocal. */
2170+
private fun integerPow(base: Float, n: Int): Float {
2171+
if (n == 0) return 1f
2172+
if (n < 0) return 1f / integerPow(base, -n)
2173+
var result = 1f
2174+
var b = base
2175+
var e = n
2176+
while (e > 0) {
2177+
if (e and 1 == 1) result *= b
2178+
b *= b
2179+
e = e ushr 1
2180+
}
2181+
return result
2182+
}
2183+
2184+
private fun scalarPow(base: Float, exp: Float): Float =
2185+
base.toDouble().pow(exp.toDouble()).toFloat()
2186+
21262187
// ---- TinyFoA ops: abs, sign, clamp, lt, ge ----
21272188

21282189
@TensorOp()
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
package sk.ainet.exec.tensor.ops
2+
3+
import kotlin.math.abs
4+
import kotlin.test.Test
5+
import kotlin.test.assertEquals
6+
import kotlin.test.assertFailsWith
7+
import kotlin.test.assertTrue
8+
import sk.ainet.lang.tensor.Shape
9+
import sk.ainet.lang.tensor.VoidOpsTensor
10+
import sk.ainet.lang.tensor.data.DenseTensorDataFactory
11+
import sk.ainet.lang.tensor.data.FloatArrayTensorData
12+
import sk.ainet.lang.types.FP32
13+
14+
/**
15+
* Forward-parity tests for the new `pow` and `powScalar` ops (Tier A
16+
* of #617). Checks both the binary form (tensor exponent) and the
17+
* scalar form for integer + real exponents.
18+
*/
19+
class DefaultCpuOpsPowTest {
20+
private val dataFactory = DenseTensorDataFactory()
21+
private val ops = DefaultCpuOps(dataFactory)
22+
23+
private fun floatTensor(shape: Shape, values: FloatArray) =
24+
VoidOpsTensor(dataFactory.fromFloatArray<FP32, Float>(shape, FP32::class, values), FP32::class)
25+
26+
private fun assertCloseTo(expected: FloatArray, actual: FloatArray, tol: Float = 1e-4f) {
27+
assertEquals(expected.size, actual.size, "length mismatch")
28+
for (i in expected.indices) {
29+
val diff = abs(expected[i] - actual[i])
30+
assertTrue(diff <= tol, "[$i] expected=${expected[i]} actual=${actual[i]} diff=$diff tol=$tol")
31+
}
32+
}
33+
34+
@Test
35+
fun powScalar_integer_2_matches_x_times_x() {
36+
val a = floatTensor(Shape(5), floatArrayOf(0.5f, 1f, 2f, 3f, -2f))
37+
val expected = floatArrayOf(0.25f, 1f, 4f, 9f, 4f)
38+
val out = ops.powScalar(a, 2)
39+
assertCloseTo(expected, (out.data as FloatArrayTensorData<*>).buffer)
40+
}
41+
42+
@Test
43+
fun powScalar_integer_3_matches_x_cubed() {
44+
val a = floatTensor(Shape(4), floatArrayOf(1f, 2f, 3f, -2f))
45+
val expected = floatArrayOf(1f, 8f, 27f, -8f)
46+
val out = ops.powScalar(a, 3)
47+
assertCloseTo(expected, (out.data as FloatArrayTensorData<*>).buffer)
48+
}
49+
50+
@Test
51+
fun powScalar_negative_integer_minus_1_is_reciprocal() {
52+
val a = floatTensor(Shape(3), floatArrayOf(2f, 4f, 0.5f))
53+
val expected = floatArrayOf(0.5f, 0.25f, 2f)
54+
val out = ops.powScalar(a, -1)
55+
assertCloseTo(expected, (out.data as FloatArrayTensorData<*>).buffer)
56+
}
57+
58+
@Test
59+
fun powScalar_real_half_is_sqrt() {
60+
val a = floatTensor(Shape(4), floatArrayOf(0f, 1f, 4f, 9f))
61+
val expected = floatArrayOf(0f, 1f, 2f, 3f)
62+
val out = ops.powScalar(a, 0.5f)
63+
assertCloseTo(expected, (out.data as FloatArrayTensorData<*>).buffer)
64+
}
65+
66+
@Test
67+
fun powScalar_real_1_5_matches_kotlin_math_pow() {
68+
val a = floatTensor(Shape(3), floatArrayOf(1f, 2f, 4f))
69+
val expected = floatArrayOf(1f, 2.828427f, 8f)
70+
val out = ops.powScalar(a, 1.5f)
71+
assertCloseTo(expected, (out.data as FloatArrayTensorData<*>).buffer)
72+
}
73+
74+
@Test
75+
fun pow_binary_element_wise() {
76+
val a = floatTensor(Shape(4), floatArrayOf(2f, 3f, 4f, 5f))
77+
val b = floatTensor(Shape(4), floatArrayOf(2f, 3f, 0.5f, 1f))
78+
val expected = floatArrayOf(4f, 27f, 2f, 5f)
79+
val out = ops.pow(a, b)
80+
assertCloseTo(expected, (out.data as FloatArrayTensorData<*>).buffer)
81+
}
82+
83+
@Test
84+
fun pow_binary_rejects_shape_mismatch() {
85+
val a = floatTensor(Shape(3), floatArrayOf(1f, 2f, 3f))
86+
val b = floatTensor(Shape(4), floatArrayOf(1f, 2f, 3f, 4f))
87+
assertFailsWith<IllegalArgumentException> { ops.pow(a, b) }
88+
}
89+
}

skainet-compile/skainet-compile-core/src/commonMain/kotlin/sk/ainet/tape/RecordingExecution.kt

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,21 @@ internal class RecordingTensorOpsDecorator(private val base: TensorOps) : Tensor
184184
return out
185185
}
186186

187+
// --- Power ops ---
188+
override fun <T : DType, V> pow(a: Tensor<T, V>, b: Tensor<T, V>): Tensor<T, V> {
189+
val out = base.pow(a, b)
190+
record(PowOperation<T, V>(), listOf(a, b), listOf(out))
191+
return out
192+
}
193+
194+
override fun <T : DType, V> powScalar(a: Tensor<T, V>, n: Number): Tensor<T, V> {
195+
val out = base.powScalar(a, n)
196+
// Single-input + scalar exponent stashed in parameters so the
197+
// backward formula can recover it (a-partial is n * a^(n-1)).
198+
record(PowOperation<T, V>(parameters = mapOf("scalar_exponent" to n)), listOf(a), listOf(out))
199+
return out
200+
}
201+
187202
// --- Scalar ops ---
188203
override fun <T : DType, V> addScalar(a: Tensor<T, V>, b: Number): Tensor<T, V> {
189204
val out = base.addScalar(a, b)

skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/graph/DefaultExecutionTape.kt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,20 @@ public class DefaultGradientTape(
643643
return listOf(null, null, null)
644644
}
645645

646+
override fun powBackward(upstream: Tensor<DType, Any>, output: Tensor<DType, Any>, inputs: List<Tensor<DType, Any>>, attributes: Map<String, Any?>): List<Tensor<DType, Any>?> {
647+
// Backward for pow(a, b): da = b*a^(b-1)*upstream, db = a^b*log(a)*upstream.
648+
// Needs `log` op (Tier B of #617) for the db partial.
649+
// First-cut Tier A stub: return null for both partials. Real formula lands in Tier C.
650+
return listOf(null, null)
651+
}
652+
653+
override fun powScalarBackward(upstream: Tensor<DType, Any>, output: Tensor<DType, Any>, inputs: List<Tensor<DType, Any>>, attributes: Map<String, Any?>): List<Tensor<DType, Any>?> {
654+
// Backward for powScalar(a, n): da = n*a^(n-1)*upstream.
655+
// Self-contained (no log needed) — but defer the formula to Tier C
656+
// alongside the rest of the autograd completeness work.
657+
return listOf(null)
658+
}
659+
646660
override fun conv2dBackward(upstream: Tensor<DType, Any>, output: Tensor<DType, Any>, inputs: List<Tensor<DType, Any>>, attributes: Map<String, Any?>): List<Tensor<DType, Any>?> {
647661
// d(conv2d(x, w, b))/dx, d(conv2d(x, w, b))/dw, d(conv2d(x, w, b))/db
648662
// This is complex and usually implemented in the backend.

skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/compile/graph/ComputeGraphExecutorTest.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,8 @@ private class TestTensorOps : TensorOps {
177177
override fun <T : DType, V> mean(tensor: Tensor<T, V>, dim: Int?): Tensor<T, V> = tensor
178178
override fun <T : DType, V> variance(tensor: Tensor<T, V>, dim: Int?): Tensor<T, V> = tensor
179179
override fun <T : DType, V> sqrt(tensor: Tensor<T, V>): Tensor<T, V> = tensor
180+
override fun <T : DType, V> pow(a: Tensor<T, V>, b: Tensor<T, V>): Tensor<T, V> = a
181+
override fun <T : DType, V> powScalar(a: Tensor<T, V>, n: Number): Tensor<T, V> = a
180182
override fun <T : DType, V> abs(tensor: Tensor<T, V>): Tensor<T, V> = tensor
181183
override fun <T : DType, V> sign(tensor: Tensor<T, V>): Tensor<T, V> = tensor
182184
override fun <T : DType, V> clamp(tensor: Tensor<T, V>, minVal: Float, maxVal: Float): Tensor<T, V> = tensor

skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/BasicMathConverter.kt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ public class BasicMathConverter : StableHloOperationConverter {
2222

2323
override val supportedOperations: Set<String> = setOf(
2424
"add", "subtract", "multiply", "divide",
25-
"sub", "mul", "div" // Common aliases
25+
"sub", "mul", "div", // Common aliases
26+
"pow"
2627
)
2728

2829
override fun convert(
@@ -101,6 +102,7 @@ public class BasicMathConverter : StableHloOperationConverter {
101102
"subtract", "sub" -> "stablehlo.subtract"
102103
"multiply", "mul" -> "stablehlo.multiply"
103104
"divide", "div" -> "stablehlo.divide"
105+
"pow" -> "stablehlo.power"
104106
else -> null
105107
}
106108
}

skainet-compile/skainet-compile-opt/src/commonMain/kotlin/sk/ainet/compile/opt/GraphOptimizationPipeline.kt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import sk.ainet.compile.opt.passes.DTypeConstraintResolutionPass
66
import sk.ainet.compile.opt.passes.DeadCodeEliminationPass
77
import sk.ainet.compile.opt.passes.LLMFusionPass
88
import sk.ainet.compile.opt.passes.OperationFusionPass
9+
import sk.ainet.compile.opt.passes.PowSpecializationPass
910
import sk.ainet.compile.opt.passes.SharedWeightDeduplicationPass
1011
import sk.ainet.compile.opt.passes.TransposeEliminationPass
1112

@@ -80,6 +81,11 @@ public class GraphOptimizationPipeline(
8081
// is the boundary where dtype problems surface — every
8182
// later pass can assume dtype-validity.
8283
DTypeConstraintResolutionPass(),
84+
// Rewrite pow(x, 2) to multiply(x, x) before fusion so
85+
// the downstream passes see the multiply form. Runs after
86+
// dtype resolution (still benefits from resolved dtypes)
87+
// and before everything else.
88+
PowSpecializationPass(),
8389
DeadCodeEliminationPass(),
8490
ConstantFoldingPass(),
8591
OperationFusionPass()
@@ -92,6 +98,7 @@ public class GraphOptimizationPipeline(
9298
public fun createAggressive(): GraphOptimizationPipeline = GraphOptimizationPipeline(
9399
passes = listOf(
94100
DTypeConstraintResolutionPass(),
101+
PowSpecializationPass(),
95102
DeadCodeEliminationPass(),
96103
ConstantFoldingPass(),
97104
OperationFusionPass()
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
package sk.ainet.compile.opt.passes
2+
3+
import sk.ainet.compile.opt.GraphOptimizationPass
4+
import sk.ainet.compile.opt.GraphOptimizationResult
5+
import sk.ainet.lang.graph.ComputeGraph
6+
import sk.ainet.lang.graph.GraphEdge
7+
import sk.ainet.lang.graph.GraphNode
8+
import sk.ainet.lang.tensor.ops.MultiplyOperation
9+
import sk.ainet.lang.tensor.ops.PowOperation
10+
11+
/**
12+
* Rewrites `powScalar(x, n)` for small integer `n` (currently `n == 2`)
13+
* into the equivalent `multiply(x, x)` chain. The downstream multiply
14+
* dispatch routes to the matmul / SIMD elementwise kernels — much
15+
* cheaper than a real `pow` per element.
16+
*
17+
* Pattern detected:
18+
* ```
19+
* PowOperation node with parameters["scalar_exponent"] == 2 and one input
20+
* ```
21+
* Replaced with:
22+
* ```
23+
* MultiplyOperation node with both inputs wired to the original input
24+
* ```
25+
*
26+
* Wider integer exponents (n = 3, 4, ...) intentionally not handled in
27+
* this first cut — each adds one more layer of multiplies and the
28+
* register-pressure / staging trade-off isn't obvious without a
29+
* benchmark. Add them when there's a workload that wants them.
30+
*/
31+
public class PowSpecializationPass : GraphOptimizationPass {
32+
33+
override val name: String = "pow-specialization"
34+
35+
override fun apply(graph: ComputeGraph): GraphOptimizationResult {
36+
val diagnostics = mutableListOf<String>()
37+
var changed = false
38+
39+
// Snapshot nodes — we mutate the graph inside the loop.
40+
val candidates = graph.nodes.filter { node ->
41+
node.operation is PowOperation<*, *> &&
42+
node.inputs.size == 1 &&
43+
exponentInt(node) == 2
44+
}
45+
46+
for (powNode in candidates) {
47+
val producer = graph.edges.firstOrNull { it.destination.id == powNode.id }
48+
?: continue
49+
val sourceNode = producer.source
50+
51+
// Build the replacement multiply node — same id so consumer
52+
// edges that target powNode.id continue to resolve.
53+
val mul = GraphNode(
54+
id = powNode.id,
55+
operation = MultiplyOperation<sk.ainet.lang.types.DType, Any>(),
56+
inputs = listOf(powNode.inputs[0], powNode.inputs[0]),
57+
outputs = powNode.outputs,
58+
metadata = powNode.metadata,
59+
)
60+
61+
// Snapshot edges before mutating.
62+
val incomingToPow = graph.edges.filter { it.destination.id == powNode.id }
63+
val outgoingFromPow = graph.edges.filter { it.source.id == powNode.id }
64+
65+
graph.removeNode(powNode)
66+
graph.addNode(mul)
67+
68+
// Wire both multiply inputs to the original x.
69+
for (i in 0..1) {
70+
graph.addEdge(
71+
GraphEdge(
72+
id = "e_${sourceNode.id}_${producer.sourceOutputIndex}__${mul.id}_$i",
73+
source = sourceNode,
74+
destination = mul,
75+
sourceOutputIndex = producer.sourceOutputIndex,
76+
destinationInputIndex = i,
77+
tensorSpec = producer.tensorSpec,
78+
),
79+
)
80+
}
81+
82+
// Restore the outgoing edges to the new node.
83+
for (edge in outgoingFromPow) {
84+
graph.addEdge(
85+
GraphEdge(
86+
id = edge.id,
87+
source = mul,
88+
destination = edge.destination,
89+
sourceOutputIndex = edge.sourceOutputIndex,
90+
destinationInputIndex = edge.destinationInputIndex,
91+
tensorSpec = edge.tensorSpec,
92+
),
93+
)
94+
}
95+
96+
// The old incoming edge to the (removed) pow node should be
97+
// cleaned up — removeNode usually does this, but defensively
98+
// remove the producer edge if it survived.
99+
for (edge in incomingToPow) {
100+
graph.removeEdge(edge)
101+
}
102+
103+
diagnostics += "Specialized pow(${sourceNode.id}, 2) -> multiply at node ${powNode.id}"
104+
changed = true
105+
}
106+
107+
return GraphOptimizationResult(graph, changed = changed, diagnostics = diagnostics)
108+
}
109+
110+
/**
111+
* Returns the integer exponent stashed in [PowOperation.parameters]
112+
* (under `"scalar_exponent"`), or `null` if absent / non-integer.
113+
*/
114+
private fun exponentInt(node: GraphNode): Int? {
115+
val raw = node.operation.parameters["scalar_exponent"] ?: return null
116+
val n = (raw as? Number)?.toDouble() ?: return null
117+
val asInt = n.toInt()
118+
return if (n == asInt.toDouble()) asInt else null
119+
}
120+
}

0 commit comments

Comments
 (0)