概要

Binary Indexed Tree (BIT, Fenwick Tree) は、部分和と要素の更新のクエリを行う木構造である。

配列 \(a_1, a_2, ..., a_N\)を管理するBITは以下のクエリを1回\(O(\log N)\)で処理できる。

  • 部分和 \(a_1 + a_2 + ... + a_i\) を求める

  • \(a_i\)に\(x\)を加える

このBITは、以下の操作が行えるデータ構造とも見なすこともできる

  • \(a_i\) の値を求める

  • \(a_i, a_{i+1}, ..., a_{N}\) に\(x\)を加える

二分探索

\(a_1 + a_2 + ... + a_i \le x\) となる最大の\(i\)は \(O(\log N)\) で計算できる。

実装(1次元BIT)

\(Q\)個のクエリを処理する場合の計算量は\(O(Q \log N)\)

# Binary Indexed Tree (Fenwick Tree)
class BIT:
    def __init__(self, n):
        self.n = n
        self.data = [0]*(n+1)
        self.el = [0]*(n+1)
    def sum(self, i):
        s = 0
        while i > 0:
            s += self.data[i]
            i -= i & -i
        return s
    def add(self, i, x):
        # assert i > 0
        self.el[i] += x
        while i <= self.n:
            self.data[i] += x
            i += i & -i
    def get(self, i, j=None):
        if j is None:
            return self.el[i]
        return self.sum(j) - self.sum(i)

クラスを使わない単純な実装

# N: クエリ処理する列のサイズ

data = [0]*(N+1)
def add(k, x):
    while k <= N:
        data[k] += x
        k += k & -k

def get(k):
    s = 0
    while k:
        s += data[k]
        k -= k & -k
    return s

# 二分探索
N0 = 2**(N-1).bit_length()
def lower_bound(x):
    w = i = 0
    k = N0
    while k:
        if i+k <= N and w + data[i+k] <= x:
            w += data[i+k]
            i += k
        k >>= 1
    return i+1

Verified

  • AOJ: "DSL_2_E: Range Query - Range Add Query (RAQ)": source (Python3, 0.77sec)

  • AtCoder: "Japan Alumni Group Summer Camp 2017 Day 1 - B問題: リス": source (PyPy3, 674ms)

実装 (2次元BIT)

BITは2次元以上に拡張することができる。

\(H \times W\)の2次元BITにおいて、\(Q\)個のクエリを処理する場合の計算量は\(O(Q \log W \log H)\)

# Binary Index Tree (2-dimension)
class BIT2:
    # H*W
    def __init__(self, h, w):
        self.w = w
        self.h = h
        self.data = [{} for i in range(h+1)]

    # O(logH*logW)
    def sum(self, i, j):
        s = 0
        data = self.data
        while i > 0:
            el = data[i]
            k = j
            while k > 0:
                s += el.get(k, 0)
                k -= k & -k
            i -= i & -i
        return s

    # O(logH*logW)
    def add(self, i, j, x):
        w = self.w; h = self.h
        data = self.data
        while i <= h:
            el = data[i]
            k = j
            while k <= w:
                el[k] = el.get(k, 0) + x
                k += k & -k
            i += i & -i

    # [x0, x1) x [y0, y1)
    def range_sum(self, x0, x1, y0, y1):
        return self.sum(x1, y1) - self.sum(x1, y0) - self.sum(x0, y1) + self.sum(x0, y0)

参考ページ