Fast Modulo Transform (FMT, 高速剰余変換), Number Theoretical Transform (NTT, 数論変換)

概要

ある程度の解説(ブログ): ATC001 - C問題: 高速フーリエ変換 (FMT解法)

mod \(P\)上で \(\omega^i \not = 1\) \((0 < i < N)\), \(\omega^N = 1\) となるような \(\omega, P, N\) を用いて計算する。 この時、 \(P = A \times N + 1\) を満たす。

順変換では \(\displaystyle f_k = \sum_{i=0}^{N-1} a_i \omega^{ik}\) (mod \(P\)) を計算し、

逆変換では \(\displaystyle a_i = \frac{1}{N} \sum_{k=0}^{N-1} f_k \omega^{-ik}\) (mod \(P\)) を計算する。

利用できる原始根と\(P\)

\(P\)

\(P-1\)

原始根

998244353

\(7 \cdot 17 \cdot 2^{23}\)

3

1004535809

\(479 \cdot 2^{21}\)

3

4253024257

\(3 \cdot 13 \cdot 13 \cdot 2^{23}\)

23

原始根\(g\)と\(P\)に対し、\(N = 2^k\) で計算する場合、\(\omega = g^{(P-1)/N}\)とすればよい。

計算量

\(O(N \log N)\)

実装

// 必要な依存関係
#include<algorithm>
#include<cassert>
using ll = long long;
using namespace std;

// x^n mod m
ll fast_pow(ll x, ll n, ll m) {
  ll r = 1;
  while(n) {
    if(n & 1) {
      r = (r * x) % m;
    }
    x = (x * x) % m;
    n >>= 1;
  }
  return r;
}

// Fast Modulo Transform
class FMT {
  int n;
  ll omega, omega_rev;
  ll p, n_rev;
  ll *tmp;

  // bit-reverse
  void bit_reverse(ll *d) {
    int i = 0;
    int ns = n >> 1, nss = n >> 2;
    for(int j = 0; j < ns; j += 2) {
      if(j < i) {
        swap(d[i], d[j]);
        swap(d[i+ns+1], d[j+ns+1]);
      }
      swap(d[i+1], d[j+ns]);
      int k = nss; i ^= k;
      while(k > i) {
        k >>= 1; i ^= k;
      }
    }
  }

  void fmt_calc(ll *a, ll base) {
    int n0 = n, m = 1;
    ll half = fast_pow(base, n/2, p);
    while(n0 > 1) {
      n0 >>= 1;
      ll w = fast_pow(base, n0, p), wk = 1;
      for(int j = 0; j < m; ++j) {
        for(int i = j; i < n; i += 2*m) {
          ll u = a[i], v = (a[i+m] * wk) % p;
          a[i] = (u + v) % p;
          a[i+m] = (u + (v*half) % p) % p;
        }
        wk = (wk * w) % p;
      }
      m <<= 1;
    }
  }
public:
  // p = a * n + 1 (n = 2^k)
  // -> omega^(p-1) ≡  1 (mod p)
  FMT(ll k, ll g, ll p) : p(p) {
    n = 1;
    while(k--) n <<= 1;


    assert((p-1) % n == 0);
    ll a = (p-1) / n;
    omega = fast_pow(g, a, p);
    omega_rev = fast_pow(omega, p-2, p);
    n_rev = fast_pow(n, p-2, p);
    tmp = new ll[n];
  }

  ~FMT() {
    delete tmp;
  }

  void fmt(ll *a, ll *result) {
    for(int i = 0; i < n; ++i) result[i] = a[i];
    bit_reverse(result);
    fmt_calc(result, omega);
  }

  void ifmt(ll *a, ll *result) {
    for(int i = 0; i < n; ++i) result[i] = (a[i] * n_rev) % p;
    bit_reverse(result);
    fmt_calc(result, omega_rev);
  }

  void convolute(ll *a, ll *b, ll *result) {
    fmt(a, tmp);
    fmt(b, result);
    for(int i = 0; i < n; ++i) tmp[i] = (result[i] * tmp[i]) % p;
    ifmt(tmp, result);
  }

  inline int size() const { return n; }
};

// example: FMT(17, 3, 1004535809)

Verified

  • AtCoder: "AtCoder Typical Contest 001 - C問題: 高速フーリエ変換": source (C++14, 436ms)