Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .scalafmt.conf
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ danglingParentheses.preset = true
rewrite.rules = [AvoidInfix, SortImports, RedundantParens, SortModifiers]
docstrings = JavaDoc
newlines.afterCurlyLambda = preserve
newlines.beforeMultiline = keep
docstrings.style = keep
docstrings.oneline = unfold

Expand Down
68 changes: 33 additions & 35 deletions src/main/scala/io/github/acl4s/Convolution.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.github.acl4s

import io.github.acl4s.internal.foreach

import StaticModInt as ModInt

def convolution[M <: Int](
Expand Down Expand Up @@ -54,24 +56,20 @@ private def convolutionFft[M <: Int](
val z = 1 << internal.ceilPow2(n + m - 1)

val _a = java.util.Arrays.copyOf(a, z)
(n until z).foreach(i => { _a(i) = ModInt() })
foreach(n until z)(i => { _a(i) = ModInt() })
butterfly(_a)

val _b = java.util.Arrays.copyOf(b, z)
(m until z).foreach(i => { _b(i) = ModInt() })
foreach(m until z)(i => { _b(i) = ModInt() })
butterfly(_b)

for (i <- 0 until z) {
_a(i) *= _b(i)
}
foreach(0 until z)(i => { _a(i) *= _b(i) })
butterflyInv(_a)

val ans = java.util.Arrays.copyOf(_a, n + m - 1)
(z until n + m - 1).foreach(i => { ans(i) = ModInt() })
foreach(z until (n + m - 1))(i => { ans(i) = ModInt() })
val iz = ModInt(z).inv
for (i <- 0.until(n + m - 1)) {
ans(i) *= iz
}
foreach(0 until (n + m - 1))(i => { ans(i) *= iz })
ans
}

Expand All @@ -86,29 +84,29 @@ private def butterfly[M <: Int](a: Array[ModInt[M]])(using m: Modulus[M]): Unit
if (h - len == 1) {
val p = 1 << (h - len - 1)
val rot = ModInt(1)
for (s <- 0 until (1 << len)) {
foreach(0 until (1 << len))(s => {
val offset = s << (h - len)
for (i <- 0 until p) {
foreach(0 until p)(i => {
val l = a(i + offset)
val r = a(i + offset + p) * rot
a(i + offset) = l + r
a(i + offset + p) = l - r
}
})
if (s + 1 != (1 << len)) {
rot *= info.rate2(java.lang.Integer.numberOfTrailingZeros(~s))
}
}
})
len += 1
} else {
// 4-base
val p = 1 << (h - len - 2)
val rot = ModInt(1)
val imag = info.root(2)
for (s <- 0 until (1 << len)) {
foreach(0 until (1 << len))(s => {
val rot2 = rot * rot
val rot3 = rot2 * rot
val offset = s << (h - len)
for (i <- 0 until p) {
foreach(0 until p)(i => {
val mod2 = m.value.toLong * m.value
val a0 = a(i + offset).value.toLong
val a1 = a(i + offset + p).value.toLong * rot.value
Expand All @@ -120,11 +118,11 @@ private def butterfly[M <: Int](a: Array[ModInt[M]])(using m: Modulus[M]): Unit
a(i + offset + 1 * p) = ModInt(a0 + a2 + (2 * mod2 - (a1 + a3)))
a(i + offset + 2 * p) = ModInt(a0 + na2 + a1na3imag)
a(i + offset + 3 * p) = ModInt(a0 + na2 + (mod2 - a1na3imag))
}
})
if (s + 1 != (1 << len)) {
rot *= info.rate3(java.lang.Integer.numberOfTrailingZeros(~s))
}
}
})
len += 2
}
}
Expand All @@ -141,29 +139,29 @@ private def butterflyInv[M <: Int](a: Array[ModInt[M]])(using m: Modulus[M]): Un
if (len == 1) {
val p = 1 << (h - len)
val iRot = ModInt(1)
for (s <- 0 until (1 << (len - 1))) {
foreach(0 until (1 << (len - 1)))(s => {
val offset = s << (h - len + 1)
for (i <- 0 until p) {
foreach(0 until p)(i => {
val l = a(i + offset)
val r = a(i + offset + p)
a(i + offset) = l + r
a(i + offset + p) = (l - r + ModInt(m.value)) * iRot
}
})
if (s + 1 != (1 << (len - 1))) {
iRot *= info.iRate2(java.lang.Integer.numberOfTrailingZeros(~s))
}
}
})
len -= 1
} else {
// 4-base
val p = 1 << (h - len)
val iRot = ModInt(1)
val iImag = info.iRoot(2)
for (s <- 0 until (1 << (len - 2))) {
foreach(0 until (1 << (len - 2)))(s => {
val iRot2 = iRot * iRot
val iRot3 = iRot2 * iRot
val offset = s << (h - len + 2)
for (i <- 0 until p) {
foreach(0 until p)(i => {
val a0 = a(i + offset + 0 * p).value.toLong
val a1 = a(i + offset + 1 * p).value.toLong
val a2 = a(i + offset + 2 * p).value.toLong
Expand All @@ -175,11 +173,11 @@ private def butterflyInv[M <: Int](a: Array[ModInt[M]])(using m: Modulus[M]): Un
a(i + offset + 1 * p) = ModInt(a0 + (m.value - a1) + a2na3iImag) * iRot
a(i + offset + 2 * p) = ModInt(a0 + a1 + (m.value - a2) + (m.value - a3)) * iRot2
a(i + offset + 3 * p) = ModInt(a0 + (m.value - a1) + (m.value - a2na3iImag)) * iRot3
}
})
if (s + 1 != (1 << (len - 2))) {
iRot *= info.iRate3(java.lang.Integer.numberOfTrailingZeros(~s))
}
}
})
len -= 2
}
}
Expand All @@ -198,14 +196,14 @@ def convolutionLong(

import Convolution.*

assert(n + m - 1 <= (1 << MAX_AB_BIT))
require(n + m - 1 <= (1 << MAX_AB_BIT))

val c1 = convolutionLongMod[MOD1.type](a, b)
val c2 = convolutionLongMod[MOD2.type](a, b)
val c3 = convolutionLongMod[MOD3.type](a, b)

val c = new Array[Long](n + m - 1)
for (i <- 0.until(n + m - 1)) {
foreach(0 until (n + m - 1))(i => {
var x = 0L
x += (c1(i) * I1) % MOD1 * M2M3
x += (c2(i) * I2) % MOD2 * M1M3
Expand Down Expand Up @@ -234,7 +232,7 @@ def convolutionLong(
}
x -= Offset((diff % 5 /* == Offset.length */ ).toInt)
c(i) = x
}
})

c
}
Expand All @@ -250,7 +248,7 @@ private def convolutionLongMod[M <: Int](
}

val z = 1 << internal.ceilPow2(n + m - 1)
assert((mod.value - 1) % z == 0)
require((mod.value - 1) % z == 0)

val a2 = a.map(ModInt(_))
val b2 = b.map(ModInt(_))
Expand All @@ -274,30 +272,30 @@ final class FftInfo[M <: Int] private (using m: Modulus[M]) {

root(rank2) = ModInt(g).pow((m.value - 1) >> rank2)
iRoot(rank2) = root(rank2).inv
for (i <- (rank2 - 1) to 0 by -1) {
foreach((rank2 - 1) to 0 by -1)(i => {
root(i) = root(i + 1) * root(i + 1)
iRoot(i) = iRoot(i + 1) * iRoot(i + 1)
}
})

{
val prod = ModInt(1)
val iProd = ModInt(1)
for (i <- 0 to (rank2 - 2)) {
foreach(0 to (rank2 - 2))(i => {
rate2(i) = root(i + 2) * prod
iRate2(i) = iRoot(i + 2) * iProd
prod *= iRoot(i + 2)
iProd *= root(i + 2)
}
})
}
{
val prod = ModInt(1)
val iProd = ModInt(1)
for (i <- 0 to (rank2 - 3)) {
foreach(0 to (rank2 - 3))(i => {
rate3(i) = root(i + 3) * prod
iRate3(i) = iRoot(i + 3) * iProd
prod *= iRoot(i + 3)
iProd *= root(i + 3)
}
})
}
}
object FftInfo {
Expand Down
20 changes: 11 additions & 9 deletions src/main/scala/io/github/acl4s/Dsu.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package io.github.acl4s

import scala.collection.mutable

import io.github.acl4s.internal.foreach

/**
* Implement (union by size) + (path compression)
* Reference:
Expand All @@ -19,8 +21,8 @@ final class Dsu(private val n: Int) {
private val parentOrSize: Array[Int] = Array.fill(n)(-1)

def merge(a: Int, b: Int): Int = {
assert(0 <= a && a < n)
assert(0 <= b && b < n)
require(0 <= a && a < n)
require(0 <= b && b < n)
var x = leader(a)
var y = leader(b)
if (x == y) { return x }
Expand All @@ -36,13 +38,13 @@ final class Dsu(private val n: Int) {
}

def same(a: Int, b: Int): Boolean = {
assert(0 <= a && a < n)
assert(0 <= b && b < n)
require(0 <= a && a < n)
require(0 <= b && b < n)
leader(a) == leader(b)
}

def leader(a: Int): Int = {
assert(0 <= a && a < n)
require(0 <= a && a < n)
if (parentOrSize(a) < 0) {
a
} else {
Expand All @@ -52,23 +54,23 @@ final class Dsu(private val n: Int) {
}

def size(a: Int): Int = {
assert(0 <= a && a < n)
require(0 <= a && a < n)
-parentOrSize(leader(a))
}

def groups(): collection.Seq[collection.Seq[Int]] = {
val leaderBuf = new Array[Int](n)
val groupSize = new Array[Int](n)
(0 until n).foreach(i => {
foreach(0 until n)(i => {
leaderBuf(i) = leader(i)
groupSize(leaderBuf(i)) += 1
})

val result = new mutable.ArrayBuffer[mutable.Buffer[Int]](n)
(0 until n).foreach(i => {
foreach(0 until n)(i => {
result.addOne(new mutable.ArrayBuffer(groupSize(i)))
})
(0 until n).foreach(i => {
foreach(0 until n)(i => {
result(leaderBuf(i)) += i
})

Expand Down
8 changes: 4 additions & 4 deletions src/main/scala/io/github/acl4s/FenwickTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package io.github.acl4s

import scala.reflect.ClassTag

import io.github.acl4s.internal.IPair
import io.github.acl4s.internal.{foreach, IPair}

/**
* Reference: https://en.wikipedia.org/wiki/Fenwick_tree
Expand All @@ -18,11 +18,11 @@ final class FenwickTree[T: ClassTag](

def this(array: Array[T])(using AddSub[T]) = {
this(array.length)
array.indices.foreach(i => add(i, array(i)))
foreach(array.indices)(i => { add(i, array(i)) })
}

def add(index: Int, x: T): Unit = {
assert(0 <= index && index < n)
require(0 <= index && index < n)
var p = index + 1
while (p <= n) {
data(p - 1) = m.combine(data(p - 1), x)
Expand All @@ -46,7 +46,7 @@ final class FenwickTree[T: ClassTag](
}

def sum(l: Int, r: Int): T = {
assert(0 <= l && l <= r && r <= n)
require(0 <= l && l <= r && r <= n)
m.subtract(sum(r), sum(l))
}
}
Expand Down
Loading