概要

ビット01から成る \(N \times N\) のビット行列の演算を行う。

行列演算において、ネックとなるのは積であり \(O(N^3)\) かかる。 さらにPythonではlistの処理が遅いため、行列サイズが \(10^2\) 程度でもTLEすることがある。

Pythonでビットを処理するとき、0もしくは1をlistで管理するよりも、1つの数値の1bitとして管理した方が高速になることが多いため、ビットパラレルアルゴリズムを用いて行列積を高速に計算するようにした。

ビットパラレルアルゴリズムを用いた行列積の計算量は \(O(N^{2+2 \log_2 3})\) 程度になる。(Python内部では多倍長乗算にKaratsuba法が用いられているため)

計算量的に通常の行列積より重そうだが、Python内部で行われる処理が多いので高速になっている。

実装

# ビット行列演算
class BitMatrix:
    __slots__ = ['n', 'mat', 'mask', 'u', 'v']
    def __init__(self, n):
        self.n = n
        self.mat = 0
        self.u = u = 2**n - 1
        self.v = ((u+1)**n - 1) / u
    def copy(self, other):
        #assert self.n == other.n
        self.mat = other.mat
        return self

    def set(self, i, j, b):
        bit = 1 << (i*self.n + j)
        if b:
            self.mat |= bit
        else:
            self.mat &= ~bit
        return self
    def setZ(self):
        self.mat = 0
        return self
    def setI(self):
        n = self.n
        self.mat = (2**((n+1)*n) - 1) / (2**(n+1) - 1)
        return self
    def get(self, i, j):
        return (self.mat >> (i*self.n + j)) & 1
    def __add__(self, other):
        res = BitMatrix(self.n)
        res.mat = self.mat ^ other.mat
        return res
    def __iadd__(self, other):
        n = self.n
        self.mat ^= other.mat
        return self
    def __mul__(self, other):
        n = self.n; u = self.u; v = self.v
        res = BitMatrix(n)
        c = 0; a = self.mat; b = other.mat
        while a and b:
            c ^= ((a & v) * u) & ((b & u) * v)
            a >>= 1; b >>= n
        res.mat = c
        return res
    def __imul__(self, other):
        n = self.n; u = self.u; v = self.v
        c = 0; a = self.mat; b = other.mat
        while a and b:
            c ^= ((a & v) * u) & ((b & u) * v)
            a >>= 1; b >>= n
        self.mat = c
        return self
    def __pow__(self, k):
        res = BitMatrix(self.n).setI()
        A = BitMatrix(self.n).copy(self)
        while k:
            if k&1:
                res *= A
            A *= A
            k >>= 1
        return res
    def __ipow__(self, k):
        if k == 0:
            return self.setI()
        A = BitMatrix(self.n).copy(self)
        k -= 1
        while k:
            if k&1:
                self *= A
            A *= A
            k >>= 1
        return self

Verified

  • AtCoder: "AtCoder Beginner Contest 009 - D問題: 漸化式": source (Python2, 1547ms)


戻る