Fast Modulo Transform (FMT, 高速剰余変換), Number Theoretic Transform (NTT, 数論変換)

概要

ある程度の解説(ブログ): ATC001 - C問題: 高速フーリエ変換 (FMT解法)

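mod \(P\)(\(P\) は素数)上で \(\omega^i \neq 1\) \((0 < i < N)\), \(\omega^N = 1\) となるような \(\omega, P, N\) を用いて計算する(すなわち \(\omega\) は 1 の原始 \(N\) 乗根)。 このような \(\omega\) が存在するためには \(N\) が \(P - 1\) を割り切る必要があり、ある整数 \(A\) について \(P = A \times N + 1\) を満たす。 以下の実装では \(N\) を 2 の冪とする。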

順変換では \(\displaystyle f_k = \sum_{i=0}^{N-1} a_i \omega^{ik}\) (mod \(P\)) を計算し、

逆変換では \(\displaystyle a_i = \frac{1}{N} \sum_{k=0}^{N-1} f_k \omega^{-ik}\) (mod \(P\)) を計算する。

計算量

\(O(N \log N)\)

実装

# Parameters for the FMT (NTT).
# omega is used as a primitive n-th root of unity modulo P, and
# P = 5880 * n + 1, so n divides P - 1 (P is assumed prime).
omega = 103
n = 1 << 18  # transform length, 2**18
P = 5880 * n + 1
# Modular inverse of omega via Fermat's little theorem: omega**(P-2) mod P.
rev = pow(omega, P - 2, P)

# バタフライ演算としてのbit反転処理
# in-placeで計算するために利用
def bit_reverse(d):
    """Permute d in place into bit-reversed index order and return it.

    Intended for len(d) a power of two (the input layout fmt_bu expects):
    the element at index j moves to index rev(j), where rev reverses the
    log2(len(d)) index bits.  Each step handles an even index j together
    with its companions j + len/2 and j + len/2 + 1.
    """
    size = len(d)
    half = size >> 1
    quarter = half >> 1
    half_plus_one = half + 1
    rj = 0  # bit-reversed value of the current even index j
    for j in range(0, half, 2):
        if j < rj:
            # Swap (j, rev(j)) and the mirrored pair offset by half + 1,
            # since rev(j + half + 1) == rev(j) + half + 1 for even j.
            d[rj], d[j] = d[j], d[rj]
            d[rj + half_plus_one], d[j + half_plus_one] = (
                d[j + half_plus_one], d[rj + half_plus_one])
        # rev(j + half) == rj + 1 for even j, so this pair is swapped
        # on every step.
        d[rj + 1], d[j + half] = d[j + half], d[rj + 1]
        # Advance rj to rev(j + 2): an increment performed from the top
        # bit downward, carrying toward the lower bits.
        bit = quarter
        rj ^= bit
        while bit > rj:
            bit >>= 1
            rj ^= bit
    return d

# FMTをループで計算
def fmt_bu(A, n, base, half, Q):
    """Run the bottom-up (iterative) radix-2 FMT butterflies in place on A.

    A must already be in bit-reversed order (see bit_reverse).  When n is a
    power of two and base is a primitive n-th root of unity mod Q, the
    result is A[k] = sum_i a[i] * base**(i*k) mod Q, where a is the input in
    natural order.  A is modified and also returned.

    Args:
        A: list of ints in bit-reversed order, expected to have n entries.
        n: transform length.
        base: root of unity for this transform (omega for the forward
            transform, its inverse for the inverse transform).
        half: multiplier for the second butterfly output; callers pass
            base**(n/2), which is -1 mod Q when base has order n.
        Q: modulus.
    """
    length = n
    half_span = 1
    stride = n
    while stride > 1:
        stride >>= 1
        # base**stride has order 2 * half_span: the twiddle root of this stage.
        root = pow(base, stride, Q)
        twiddle = 1
        for offset in range(half_span):
            for lo in range(offset, length, 2 * half_span):
                hi = lo + half_span
                # even = g(w^{2k}), odd = w^k * h(w^{2k})
                even = A[lo]
                odd = A[hi] * twiddle % Q
                A[lo] = (even + odd) % Q
                A[hi] = (even + odd * half) % Q
            twiddle = twiddle * root % Q
        half_span <<= 1
    return A

# FMTの順変換
def fmt(f, l, Q=P):
    """Forward FMT: return F with F[k] = sum_i f[i] * w**(i*k) mod Q.

    w is a primitive len(f)-th root of unity derived from the module-level
    omega (a primitive n-th root of unity mod P).  The input is not
    modified; a new list is returned.

    Args:
        f: coefficient sequence; its length must be a power of two that
            divides the module-level n.
        l: transform length as passed by callers; only l == 1 is acted on
            (f is then returned unchanged).
        Q: modulus; omega only has the required order modulo P.

    Raises:
        ValueError: if len(f) is not a power of two dividing n.
    """
    if l == 1:
        return f
    size = len(f)
    # omega has order n, so omega**(n // size) has order size when size | n.
    # The previous version always used the global n here, which broke (or
    # silently produced garbage) for any input whose length was not n.
    if size == 0 or size & (size - 1) or n % size:
        raise ValueError("length must be a power of two dividing %d" % n)
    A = list(f)
    bit_reverse(A)
    root = pow(omega, n // size, Q)
    return fmt_bu(A, size, root, pow(root, size // 2, Q), Q)

# FMTの逆変換
def ifmt(F, l, Q=P):
    """Inverse FMT: return a with a[i] = (1/N) sum_k F[k] * w**(-i*k) mod Q.

    N = len(F), and w**-1 is derived from the module-level rev (the inverse
    of omega mod P).  The input is not modified; a new list is returned.

    Args:
        F: transformed sequence; its length must be a power of two that
            divides the module-level n.
        l: transform length as passed by callers; only l == 1 is acted on
            (F is then returned unchanged).
        Q: modulus, assumed prime (1/N is computed with Fermat's theorem).

    Raises:
        ValueError: if len(F) is not a power of two dividing n.
    """
    if l == 1:
        return F
    size = len(F)
    # Same length handling as fmt: use the actual input length instead of
    # the global n for the butterflies, the root, and the 1/N factor.
    if size == 0 or size & (size - 1) or n % size:
        raise ValueError("length must be a power of two dividing %d" % n)
    A = list(F)
    bit_reverse(A)
    # The inverse transform is the forward butterflies run with w**-1.
    inv_root = pow(rev, n // size, Q)
    f = fmt_bu(A, size, inv_root, pow(inv_root, size // 2, Q), Q)
    # Divide by N: multiply by N**(Q-2) == N**-1 mod Q.
    size_inv = pow(size, Q - 2, Q)
    return [e * size_inv % Q for e in f]

# FMTを利用した畳込み処理
def convolute(a, b, l, Q=P):
    """Return the cyclic convolution of a and b modulo Q, computed with FMT.

    Both sequences are forward-transformed, multiplied pointwise, and
    inverse-transformed.  For an ordinary (non-cyclic) product, callers
    must zero-pad a and b so the result fits in the transform length.
    """
    fa = fmt(a, l, Q)
    fb = fmt(b, l, Q)
    return ifmt([x * y % Q for x, y in zip(fa, fb)], l, Q)

Verified

  • AtCoder: "AtCoder Typical Contest 001 - C問題: 高速フーリエ変換": source (Python2, 3704ms)

実装 (2次元)

\(N \times N\) の配列に対して、計算量は \(O(N^2 \log N)\)

# Parameters for the 2-D FMT: omega is used as a primitive n-th root of
# unity modulo P = 1359 * n + 1 (P is assumed prime).
omega = 139
n = 1 << 9  # side length of the square grid, 2**9
P = 1359 * n + 1
# Modular inverse of omega via Fermat's little theorem.
rev = pow(omega, P - 2, P)

def bit_reverse(d):
    """Permute d in place into bit-reversed index order and return it.

    Intended for len(d) a power of two (the input layout fmt_bu expects):
    the element at index j moves to index rev(j), where rev reverses the
    log2(len(d)) index bits.  Each step handles an even index j together
    with its companions j + len/2 and j + len/2 + 1.
    """
    size = len(d)
    half = size >> 1
    quarter = half >> 1
    half_plus_one = half + 1
    rj = 0  # bit-reversed value of the current even index j
    for j in range(0, half, 2):
        if j < rj:
            # Swap (j, rev(j)) and the mirrored pair offset by half + 1,
            # since rev(j + half + 1) == rev(j) + half + 1 for even j.
            d[rj], d[j] = d[j], d[rj]
            d[rj + half_plus_one], d[j + half_plus_one] = (
                d[j + half_plus_one], d[rj + half_plus_one])
        # rev(j + half) == rj + 1 for even j, so this pair is swapped
        # on every step.
        d[rj + 1], d[j + half] = d[j + half], d[rj + 1]
        # Advance rj to rev(j + 2): an increment performed from the top
        # bit downward, carrying toward the lower bits.
        bit = quarter
        rj ^= bit
        while bit > rj:
            bit >>= 1
            rj ^= bit
    return d

def fmt_bu(A, n, base, half, Q):
    """Run the bottom-up (iterative) radix-2 FMT butterflies in place on A.

    A must already be in bit-reversed order (see bit_reverse).  When n is a
    power of two and base is a primitive n-th root of unity mod Q, the
    result is A[k] = sum_i a[i] * base**(i*k) mod Q, where a is the input in
    natural order.  A is modified and also returned.

    Args:
        A: list of ints in bit-reversed order, expected to have n entries.
        n: transform length.
        base: root of unity for this transform (omega for the forward
            transform, its inverse for the inverse transform).
        half: multiplier for the second butterfly output; callers pass
            base**(n/2) (or another value that is -1 mod Q).
        Q: modulus.
    """
    length = n
    half_span = 1
    stride = n
    while stride > 1:
        stride >>= 1
        # base**stride has order 2 * half_span: the twiddle root of this stage.
        root = pow(base, stride, Q)
        twiddle = 1
        for offset in range(half_span):
            for lo in range(offset, length, 2 * half_span):
                hi = lo + half_span
                even = A[lo]
                odd = A[hi] * twiddle % Q
                A[lo] = (even + odd) % Q
                A[hi] = (even + odd * half) % Q
            twiddle = twiddle * root % Q
        half_span <<= 1
    return A

def fmt2d(f, l, Q=P):
    """Forward 2-D FMT of an n x n grid f modulo Q.

    Transforms every row, transposes, transforms every row again (i.e. each
    original column), and transposes back.  Reads rows f[0] .. f[n-1];
    returns a list of n tuples.  The l argument is not used.
    """
    def transform_rows(grid):
        # 1-D FMT on each of the first n rows, then transpose via zip.
        out = [None] * n
        for r in range(n):
            row = list(grid[r])
            bit_reverse(row)
            out[r] = fmt_bu(row, n, omega, pow(omega, n // 2, Q), Q)
        return list(zip(*out))

    return transform_rows(transform_rows(f))

def ifmt2d(F, l, Q=P):
    """Inverse 2-D FMT of an n x n grid F modulo Q.

    Inverse-transforms every row, then every column (via transposition), and
    divides by n*n (Q is assumed prime for the Fermat inverse).  Returns a
    list of n lists.  The l argument is not used.
    """
    # The inverse butterflies use rev = omega**-1, so their "half" factor is
    # rev**(n/2), matching the 1-D ifmt.  The original passed omega**(n/2),
    # which only coincides because both equal -1 mod P.  Hoisted because it
    # is loop-invariant (it was recomputed 2n times).
    half = pow(rev, n // 2, Q)
    tmp = [None] * n
    for i in range(n):
        A = list(F[i])
        bit_reverse(A)
        tmp[i] = fmt_bu(A, n, rev, half, Q)
    *tmp, = zip(*tmp)
    f = [None] * n
    for i in range(n):
        A = list(tmp[i])
        bit_reverse(A)
        f[i] = fmt_bu(A, n, rev, half, Q)
    *f, = zip(*f)
    # Divide by n*n: multiply by (n*n)**(Q-2) == (n*n)**-1 mod Q.
    n2_rev = pow(n * n, Q - 2, Q)
    return [[e * n2_rev % Q for e in fi] for fi in f]

def convolute2d(a, b, l, Q=P):
    """Return the cyclic 2-D convolution of n x n grids a and b modulo Q.

    Both grids are forward-transformed with fmt2d, multiplied elementwise,
    and inverse-transformed with ifmt2d.
    """
    fa = fmt2d(a, l, Q)
    fb = fmt2d(b, l, Q)
    prod = [[x * y % Q for x, y in zip(ra, rb)] for ra, rb in zip(fa, fb)]
    return ifmt2d(prod, l, Q)

Verified

  • AOJ: "2977 - Bombing": source (Python3, 5.11sec)


戻る