Binary Indexed Tree, BIT, Fenwick Tree
概要
部分和と要素の更新のクエリを行う木構造。
配列 \(a_1, a_2, ..., a_N\) を管理するBITは以下のクエリを1回 \(O(\log N)\) で処理できる。
-
部分和 \(a_1 + a_2 + ... + a_i\) を求める
-
\(a_i\) に \(x\) を加える
このBITは、以下の操作が行えるデータ構造とも見なすこともできる
-
\(a_i\) の値を求める
-
\(a_i, a_{i+1}, ..., a_{N}\) に \(x\) を加える
BIT の初期化
\(a_1, a_2, ..., a_N\) の初期値が 0 以外の場合を考える。
\(a_1, a_2, ..., a_k\) の和を \(S_k\) と定義する時、以下により \(O(N)\) で初期化できる。
-
各 \(i = 1, 2, ..., N\) について \(S_k \leftarrow a_k\) で初期化
-
\(i = 1\) から \(N\) まで順に \(S_{i + (i\&-i)} \leftarrow S_i + S_{i + (i\&-i)}\) で更新
二分探索
\(a_1 + a_2 + ... + a_i \le x\) となる最大の \(i\) は \(O(\log N)\) で計算できる。
実装(1次元BIT)
\(Q\) 個のクエリを処理する場合の計算量は \(O(Q \log N)\)
# Binary Indexed Tree (Fenwick Tree)
class BIT:
def __init__(self, n):
self.n = n
self.n0 = 2**(n-1).bit_length()
self.data = [0]*(n+1)
self.el = [0]*(n+1)
def init(self, A):
self.data[1:] = A
for i in range(1, self.n):
if i + (i & -i) <= self.n:
self.data[i + (i & -i)] += self.data[i]
def sum(self, i):
s = 0
while i > 0:
s += self.data[i]
i -= i & -i
return s
def add(self, i, x):
# assert i > 0
self.el[i] += x
while i <= self.n:
self.data[i] += x
i += i & -i
def get(self, i, j=None):
if j is None:
return self.el[i]
return self.sum(j) - self.sum(i)
def lower_bound(self, x):
w = i = 0
k = self.n0
while k:
if i+k <= self.n and w + self.data[i+k] <= x:
w += self.data[i+k]
i += k
k >>= 1
# assert self.get(0, i) <= x < self.get(0, i+1)
return i+1
クラスを使わない単純な実装
# N: クエリ処理する列のサイズ
N = ...
data = [0]*(N+1)
def init(A):
data[1:] = A
for i in range(1, N):
if i + (i & -i) <= N:
data[i + (i & -i)] += data[i]
def add(k, x):
while k <= N:
data[k] += x
k += k & -k
def get(k):
s = 0
while k:
s += data[k]
k -= k & -k
return s
# 二分探索
N0 = 2**(N-1).bit_length()
def lower_bound(x):
w = i = 0
k = N0
while k:
if i+k <= N and w + data[i+k] <= x:
w += data[i+k]
i += k
k >>= 1
return i+1
実装 (2次元BIT)
BITは2次元以上に拡張することができる。
\(H \times W\) の2次元BITにおいて、 \(Q\) 個のクエリを処理する場合の計算量は \(O(Q \log W \log H)\)
# Binary Index Tree (2-dimension)
class BIT2:
# H*W
def __init__(self, h, w):
self.w = w
self.h = h
self.data = [{} for i in range(h+1)]
# O(logH*logW)
def sum(self, i, j):
s = 0
data = self.data
while i > 0:
el = data[i]
k = j
while k > 0:
s += el.get(k, 0)
k -= k & -k
i -= i & -i
return s
# O(logH*logW)
def add(self, i, j, x):
w = self.w; h = self.h
data = self.data
while i <= h:
el = data[i]
k = j
while k <= w:
el[k] = el.get(k, 0) + x
k += k & -k
i += i & -i
# [x0, x1) x [y0, y1)
def range_sum(self, x0, x1, y0, y1):
return self.sum(x1, y1) - self.sum(x1, y0) - self.sum(x0, y1) + self.sum(x0, y0)