概要

乱択アルゴリズムを用いた平衡二分探索木。 平衡二分探索木の中で実装が簡単なアルゴリズム。

木の高さは\(O(\log N)\)になる。

計算量

各クエリ \(O(\log N)\)

実装

insert-eraseベース

各ノードに以下の情報を持たせている

  • 左右の子ノード

  • key

  • priority

  • そのノード以下の部分木のサイズ

高速に処理するため、非再帰でノードの追加、削除を行うようにしている。

# Treap

# d = 0: right rotation
# d = 1: left rotation
def rotate(nd, d):
    c = nd[d]
    if d:
        e = c[1]
        nd[1] = c[0]
        c[0] = nd
    else:
        e = c[0]
        nd[0] = c[1]
        c[1] = nd

    r = c[4] = nd[4]
    nd[4] = r - (e[4] if e else 0) - 1
    return c

# insert a node with key = val and priority = pri
def insert(t, val, pri):
    st = []
    dr = []
    x = t
    while x:
        st.append(x)
        if x[2] == val:
            return t
        d = (x[2] < val)
        dr.append(d)
        x = x[d]

    # [<left>, <right>, <key>, <priority>, <count>]
    nd = [None, None, val, pri, 1]
    while st:
        x = st.pop(); d = dr.pop()
        x[d] = nd
        x[4] += 1
        if x[3] >= nd[3]:
            break
        rotate(x, d)
    else:
        return nd

    for x in st:
        x[4] += 1
    return t

# delete node nd
def __delete(nd):
    st = []; dr = []
    while nd[0] or nd[1]:
        l = nd[0]; r = nd[1]
        d = (l[3] <= r[3]) if l and r else (l is None)
        st.append(rotate(nd, d))
        dr.append(d ^ 1)
    nd = x = None
    while st:
        nd = x; x = st.pop(); d = dr.pop()
        x[d] = nd
        x[4] -= 1
    return x

# delete a node with key = val
def delete(t, val):
    x = t

    st = []
    y = None
    while x:
        if val == x[2]:
            break
        y = x; d = (x[2] < val)
        st.append(y)
        x = x[d]
    else:
        return

    if y:
        y[d] = __delete(x)
        for x in st:
            x[4] -= 1
        return t
    return __delete(x)

# find a node with key = val
def find(t, val):
    x = t
    while x:
        if val == x[2]:
            return 1
        x = x[x[2] < val]
    return 0

# merge tree l with tree r
def merge(l, r):
    if not l:
        return r
    if not r:
        return l
    nd = [l, r, None, None, l[4]+r[4]+1]
    return __delete(nd)

# split tree t into two trees of size k and |t|-k
def split(t, k):
    x = t
    if not 0 < k < x[4]:
        return t, None
    st = []; dr = []
    while x:
        l = x[0]

        c = (l[4] if l else 0) + 1
        if c == k:
            break
        y = x
        st.append(y)

        if c < k:
            k -= c
            x = x[1]
            d = 1
        else:
            x = l
            d = 0
        dr.append(d)
    # x is k-th element in tree t
    #assert x
    st.append(x)
    dr.append(1)
    x = x[1]
    while x:
        st.append(x)
        dr.append(0)
        x = x[0]
    nd = [None, None, None, None, 1]
    while st:
        x = st.pop(); d = dr.pop()
        x[d] = nd
        x[4] += 1
        rotate(x, d)
    return nd[0], nd[1]

# for debug
def debug(t):
    def dfs(t, k):
        cc = 1
        if t[0]:
            cc += dfs(t[0], k+1)
        print(" "*k, "%d (%d) %d" % (t[2], t[3], t[4]))
        if t[1]:
            cc += t[1] and dfs(t[1], k+1)
        assert t[4] == cc
        return cc
    if not t:
        return 0
    return dfs(t, 0)

Verified

  • AtCoder: "AtCoder Regular Contest 033 - C問題: データ構造": source (PyPy3, 1637ms)

部分木サイズを持たない実装

  • AOJ: "ALDS1_8_D: Binary search trees - Treap": source (Python3, 3.00sec)

参考ページ