Fast Modulo Transform (FMT, 高速剰余変換), Number Theoretic Transform (NTT, 数論変換)
概要
ある程度の解説(ブログ): ATC001 - C問題: 高速フーリエ変換 (FMT解法)
素数 \(P\) を法として \(\omega^i \neq 1\) \((0 < i < N)\), \(\omega^N = 1\) となるような \(\omega, P, N\) を用いて計算する。 この時、\(\omega\) の位数 \(N\) は \(P - 1\) を割り切るので、ある正の整数 \(A\) に対して \(P = A \times N + 1\) を満たす。
順変換では \(\displaystyle f_k = \sum_{i=0}^{N-1} a_i \omega^{ik}\) (mod \(P\)) を計算し、
逆変換では \(\displaystyle a_i = \frac{1}{N} \sum_{k=0}^{N-1} f_k \omega^{-ik}\) (mod \(P\)) を計算する。
計算量
\(O(N \log N)\)
実装
# Parameters for the FMT.
#   omega : value used as the n-th root of unity modulo P
#   n     : transform length (a power of two)
#   P     : modulus, of the form A * n + 1 with A = 5880
#   rev   : omega ** (P - 2) mod P; by Fermat's little theorem this is the
#           inverse of omega (relies on P being prime)
omega = 103
n = 1 << 18
P = 5880 * n + 1
rev = pow(omega, P - 2, P)
# Bit-reversal permutation that precedes the butterfly passes.
# It reorders the data so that the transform can then be computed in place.
def bit_reverse(d):
    """Reorder ``d`` in place into bit-reversed index order and return it.

    For a list whose length is a power of two, the element stored at index
    ``p`` ends up at the index obtained by reversing the binary digits of
    ``p``.  The length is not validated; for other lengths the resulting
    permutation is not a bit reversal.

    Args:
        d: Mutable sequence to permute (modified in place).

    Returns:
        The same object ``d``.
    """
    n = len(d)
    ns = n >> 1          # half of the length
    nss = ns >> 1        # quarter of the length: top bit of the reversed counter
    ns1 = ns + 1
    # ``j`` walks the even indices of the lower half while ``i`` tracks the
    # bit-reversed image of ``j``.
    i = 0
    for j in range(0, ns, 2):
        if j < i:
            # Swap the pair (j, i) and its mirror pair in the upper/odd part.
            d[i], d[j] = d[j], d[i]
            d[i + ns1], d[j + ns1] = d[j + ns1], d[i + ns1]
        # j + ns (top bit set) always maps to i + 1 (bottom bit set).
        d[i + 1], d[j + ns] = d[j + ns], d[i + 1]
        # Advance ``i`` as a counter that carries from the high bit downwards.
        k = nss
        i ^= k
        while k > i:
            k >>= 1
            i ^= k
    return d
# FMT computed iteratively (bottom-up) with loops.
def fmt_bu(A, n, base, half, Q):
    """Run the butterfly passes of the transform on ``A`` in place.

    ``A`` is expected to be in bit-reversed order already.  No ``1/n``
    scaling is applied; an inverse transform has to do that itself.

    Args:
        A: Mutable list of integers; its first ``n`` entries are transformed.
        n: Transform length (a power of two).
        base: Root of unity of order ``n`` modulo ``Q`` (or its inverse for
            the inverse transform).
        half: ``base ** (n // 2) % Q``.
        Q: Modulus.

    Returns:
        The same list ``A``, now holding the transformed values.
    """
    size = n
    m = 1  # half-width of the blocks merged in the current pass
    while n > 1:
        n >>= 1
        # w satisfies w^(2m) == 1: the root of unity for blocks of width 2m.
        w = pow(base, n, Q)
        wk = 1  # w^j, the twiddle factor shared by offset j of every block
        for j in range(m):
            for i in range(j, size, 2 * m):
                # U = g(w^(2k)), V = w^k * h(w^(2k)), where g and h are the
                # two half-size sub-transforms being merged.
                U = A[i]
                V = (A[i + m] * wk) % Q
                A[i] = (U + V) % Q
                # half = base^(N/2), the factor for the second output.
                A[i + m] = (U + V * half) % Q
            wk = (wk * w) % Q
        m <<= 1
    return A
# Forward FMT.
def fmt(f, l, Q=P):
    """Return the forward transform of ``f`` modulo ``Q`` as a new list.

    The transform length is ``len(f)``.  It must be a power of two that
    divides the module-level ``n``; ``omega ** (n // len(f))`` is then used
    as the root of unity, which for ``len(f) == n`` is ``omega`` itself.

    Args:
        f: Sequence of integers to transform (left unmodified).
        l: Nominal length.  Only ``l == 1`` is consulted: in that case ``f``
            itself is returned unchanged.
        Q: Modulus; the default ``P`` matches ``omega`` and ``n``.

    Returns:
        A new list with the transformed values (or ``f`` when ``l == 1``).

    Raises:
        ValueError: If ``len(f)`` is not a power of two dividing ``n``.
    """
    if l == 1:
        return f
    A = list(f)
    size = len(A)
    if size == 0 or size & (size - 1) or n % size:
        raise ValueError(
            "input length must be a power of two dividing n, got %d" % size
        )
    # Bit reversal
    bit_reverse(A)
    # omega has order n, so omega^(n/size) has order size.
    base = pow(omega, n // size, Q)
    return fmt_bu(A, size, base, pow(base, size // 2, Q), Q)
# Inverse FMT.
def ifmt(F, l, Q=P):
    """Return the inverse transform of ``F`` modulo ``Q`` as a new list.

    The transform length is ``len(F)``.  It must be a power of two that
    divides the module-level ``n``; ``rev ** (n // len(F))`` is then used as
    the inverse root of unity, which for ``len(F) == n`` is ``rev`` itself.

    Args:
        F: Sequence of transformed values (left unmodified).
        l: Nominal length.  Only ``l == 1`` is consulted: in that case ``F``
            itself is returned unchanged.
        Q: Modulus; the default ``P`` matches ``rev`` and ``n``.

    Returns:
        A new list with the recovered values (or ``F`` when ``l == 1``).

    Raises:
        ValueError: If ``len(F)`` is not a power of two dividing ``n``.
    """
    if l == 1:
        return F
    A = list(F)
    size = len(A)
    if size == 0 or size & (size - 1) or n % size:
        raise ValueError(
            "input length must be a power of two dividing n, got %d" % size
        )
    # Bit reversal
    bit_reverse(A)
    # Inverse transform: use omega^-1 (rev) in place of omega.
    base = pow(rev, n // size, Q)
    f = fmt_bu(A, size, base, pow(base, size // 2, Q), Q)
    # Divide by the transform length (multiply by its modular inverse,
    # obtained as size^(Q-2), which relies on Q being prime).
    size_inv = pow(size, Q - 2, Q)
    return [(e * size_inv) % Q for e in f]
# Convolution implemented with the FMT.
def convolute(a, b, l, Q=P):
    """Convolve ``a`` and ``b`` modulo ``Q`` via the transform.

    Both inputs are transformed with ``fmt``, multiplied element-wise and
    transformed back with ``ifmt``.  The result is cyclic over the transform
    length, so callers that want a plain (linear) convolution should
    zero-pad ``a`` and ``b`` far enough to avoid wrap-around.

    Args:
        a: First sequence, in the form accepted by ``fmt``.
        b: Second sequence, same length as ``a``.
        l: Length argument forwarded unchanged to ``fmt`` and ``ifmt``.
        Q: Modulus forwarded to ``fmt`` and ``ifmt``.

    Returns:
        The list returned by ``ifmt`` for the element-wise product.
    """
    A = fmt(a, l, Q)
    B = fmt(b, l, Q)
    # Element-wise product in the transformed domain.
    C = [(s * t) % Q for s, t in zip(A, B)]
    return ifmt(C, l, Q)
Verified
-
AtCoder: "AtCoder Typical Contest 001 - C問題: 高速フーリエ変換": source (Python2, 3704ms)
実装 (2次元)
計算量は \(O(N^2 \log N)\)
# Parameters for the 2-D FMT.
#   omega : value used as the n-th root of unity modulo P
#   n     : side length of the transformed grid (a power of two)
#   P     : modulus, of the form A * n + 1 with A = 1359
#   rev   : omega ** (P - 2) mod P; by Fermat's little theorem this is the
#           inverse of omega (relies on P being prime)
omega = 139
n = 1 << 9
P = 1359 * n + 1
rev = pow(omega, P - 2, P)
def bit_reverse(d):
    """Reorder ``d`` in place into bit-reversed index order and return it.

    For a list whose length is a power of two, the element stored at index
    ``p`` ends up at the index obtained by reversing the binary digits of
    ``p``.  The length is not validated; for other lengths the resulting
    permutation is not a bit reversal.

    Args:
        d: Mutable sequence to permute (modified in place).

    Returns:
        The same object ``d``.
    """
    n = len(d)
    ns = n >> 1          # half of the length
    nss = ns >> 1        # quarter of the length: top bit of the reversed counter
    ns1 = ns + 1
    # ``j`` walks the even indices of the lower half while ``i`` tracks the
    # bit-reversed image of ``j``.
    i = 0
    for j in range(0, ns, 2):
        if j < i:
            # Swap the pair (j, i) and its mirror pair in the upper/odd part.
            d[i], d[j] = d[j], d[i]
            d[i + ns1], d[j + ns1] = d[j + ns1], d[i + ns1]
        # j + ns (top bit set) always maps to i + 1 (bottom bit set).
        d[i + 1], d[j + ns] = d[j + ns], d[i + 1]
        # Advance ``i`` as a counter that carries from the high bit downwards.
        k = nss
        i ^= k
        while k > i:
            k >>= 1
            i ^= k
    return d
def fmt_bu(A, n, base, half, Q):
    """Run the butterfly passes of the transform on ``A`` in place.

    ``A`` is expected to be in bit-reversed order already.  No ``1/n``
    scaling is applied; an inverse transform has to do that itself.

    Args:
        A: Mutable list of integers; its first ``n`` entries are transformed.
        n: Transform length (a power of two).
        base: Root of unity of order ``n`` modulo ``Q`` (or its inverse for
            the inverse transform).
        half: Factor applied to the second butterfly output; the callers
            pass the ``n // 2``-th power of a root of unity modulo ``Q``.
        Q: Modulus.

    Returns:
        The same list ``A``, now holding the transformed values.
    """
    size = n
    m = 1  # half-width of the blocks merged in the current pass
    while n > 1:
        n >>= 1
        # w satisfies w^(2m) == 1: the root of unity for blocks of width 2m.
        w = pow(base, n, Q)
        wk = 1  # w^j, the twiddle factor shared by offset j of every block
        for j in range(m):
            for i in range(j, size, 2 * m):
                # U and V are the outputs of the two half-size sub-transforms
                # being merged, V already multiplied by its twiddle factor.
                U = A[i]
                V = (A[i + m] * wk) % Q
                A[i] = (U + V) % Q
                A[i + m] = (U + V * half) % Q
            wk = (wk * w) % Q
        m <<= 1
    return A
def fmt2d(f, l, Q=P):
    """Return the 2-D forward transform of the grid ``f`` modulo ``Q``.

    A 1-D transform of length ``n`` (the module-level constant) is applied
    to each of the first ``n`` rows, the result is transposed, the same
    transform is applied again, and the result is transposed back.

    Args:
        f: Grid of integers indexed as ``f[row][col]``; expected to be
            ``n`` x ``n``.  It is not modified.
        l: Unused; kept for interface compatibility.
        Q: Modulus; the default ``P`` matches ``omega`` and ``n``.

    Returns:
        A list of ``n`` tuples holding the transformed grid.
    """
    # omega^(n/2) is identical for every row and column: compute it once.
    half = pow(omega, n // 2, Q)

    # First pass: transform along each row.
    tmp = []
    for i in range(n):
        A = list(f[i])
        bit_reverse(A)
        tmp.append(fmt_bu(A, n, omega, half, Q))

    # Transpose so that the second pass runs along the original columns.
    tmp = list(zip(*tmp))
    F = []
    for i in range(n):
        A = list(tmp[i])
        bit_reverse(A)
        F.append(fmt_bu(A, n, omega, half, Q))

    # Transpose back to the original row/column orientation.
    return list(zip(*F))
def ifmt2d(F, l, Q=P):
    """Return the 2-D inverse transform of the grid ``F`` modulo ``Q``.

    A 1-D inverse pass of length ``n`` (the module-level constant) is
    applied to each of the first ``n`` rows, the result is transposed, the
    pass is applied again, the result is transposed back, and every entry
    is finally scaled by the modular inverse of ``n * n``.

    Args:
        F: Transformed grid indexed as ``F[row][col]``; expected to be
            ``n`` x ``n``.  It is not modified.
        l: Unused; kept for interface compatibility.
        Q: Modulus; the default ``P`` matches ``rev`` and ``n``.

    Returns:
        A list of lists holding the recovered grid.
    """
    # The base of the inverse pass is rev, so ``half`` is derived from rev
    # as well.  Since omega^n == 1, rev^(n/2) equals omega^(n/2), so the
    # value is unchanged; it is computed once for all rows and columns.
    half = pow(rev, n // 2, Q)

    # First pass: inverse transform along each row.
    tmp = []
    for i in range(n):
        A = list(F[i])
        bit_reverse(A)
        tmp.append(fmt_bu(A, n, rev, half, Q))

    # Transpose so that the second pass runs along the original columns.
    tmp = list(zip(*tmp))
    f = []
    for i in range(n):
        A = list(tmp[i])
        bit_reverse(A)
        f.append(fmt_bu(A, n, rev, half, Q))

    # Divide by n^2 (one factor n per dimension) via the modular inverse,
    # obtained as (n*n)^(Q-2), which relies on Q being prime.
    n2_rev = pow(n * n, Q - 2, Q)
    return [[(e * n2_rev) % Q for e in fi] for fi in zip(*f)]
def convolute2d(a, b, l, Q=P):
    """Convolve the grids ``a`` and ``b`` modulo ``Q`` via the 2-D transform.

    Both grids are transformed with ``fmt2d``, multiplied element-wise and
    transformed back with ``ifmt2d``.  The result is cyclic in both
    dimensions, so callers that want a plain (linear) convolution should
    zero-pad the inputs far enough to avoid wrap-around.

    Args:
        a: First grid, in the form accepted by ``fmt2d``.
        b: Second grid, same shape as ``a``.
        l: Length argument forwarded unchanged to ``fmt2d`` and ``ifmt2d``.
        Q: Modulus forwarded to ``fmt2d`` and ``ifmt2d``.

    Returns:
        The grid returned by ``ifmt2d`` for the element-wise product.
    """
    A = fmt2d(a, l, Q)
    B = fmt2d(b, l, Q)
    # Element-wise product in the transformed domain.
    C = [
        [(s * t) % Q for s, t in zip(Ai, Bi)]
        for Ai, Bi in zip(A, B)
    ]
    return ifmt2d(C, l, Q)
Verified
-
AOJ: "2977 - Bombing": source (Python3, 5.11sec)