概要

平衡二分探索木を用いてConvex Hull Trickを行う。

最小値をとり得る\(x\)が存在する直線\(y = ax + b\)を傾き降順で管理し、\(x = k\)で最小値を取る直線を二分探索で求める。

Pythonには標準ライブラリに平衡二分探索木がないため、自力で実装する必要がある。(また、平衡二分探索木は定数倍が大きく、Pythonでは実行時間が厳しくなりやすい)

計算量

\(O((N + Q) \log N)\)

実装

Treap を用いた実装。

# Treap
import random
# Called with no argument, random.seed() seeds the module-level generator from
# an OS randomness source (falling back to the current time), so the priorities
# drawn by Treap.insert via random.random() -- and hence tree shapes -- vary
# from run to run.
random.seed()
class Treap:
    """Treap (randomized balanced BST) over distinct, totally ordered keys.

    Besides the usual tree links, every node is threaded into a doubly linked
    list in key order, so predecessor/successor access is O(1).
    ``self.left`` / ``self.right`` always point at the minimum / maximum node
    (None when empty), and ``self.ref`` maps every stored key to its node.

    Priorities form a max-heap: a parent's priority is >= its children's.
    Nodes are plain lists rather than objects (presumably for speed); their
    layout is documented in ``__make``.
    """

    def __init__(self):
        self.root = None
        self.left = self.right = None  # minimum / maximum node, or None
        self.ref = {}  # key -> node, for O(1) lookup of stored keys

    @staticmethod
    def __make(x, prob):
        """Return a fresh, unlinked node for key ``x`` with priority ``prob``."""
        # [0] key, [1] priority, [2] child count (unused), [3] left child,
        # [4] right child, [5] parent, [6] prev node (key order), [7] next node
        return [x, prob, 0, None, None, None, None, None]

    def empty(self):
        """Return True if the treap holds no keys."""
        return not self.root

    def find_node(self, x):
        """Return the node for ``x``, or the nearest node if ``x`` is absent.

        If ``x`` is stored, its node is returned. Otherwise the result is the
        node with the smallest key greater than ``x``. If every key is smaller
        than ``x``, the maximum node is returned instead. The result is None
        only when the treap is empty. Callers must compare ``node[0]`` with
        ``x`` to tell these cases apart.
        """
        cur = self.root
        parent = cur
        if x in self.ref:
            return self.ref[x]
        while cur:
            if x == cur[0]:
                return cur
            parent = cur
            cur = cur[3] if x < cur[0] else cur[4]
        # The search fell off below ``parent``. If parent's key is smaller
        # than x, the successor of x is parent's in-order next node.
        if parent and parent[0] < x and parent[7]:
            parent = parent[7]
        return parent

    def find_min(self, f):
        """Return the key ``k`` minimizing ``f(k)``, assuming unimodality.

        ``f`` must be non-increasing and then non-decreasing along key order,
        for example the values of lower-envelope lines at a fixed x. The
        method binary-searches the tree for the last node whose in-order
        successor is not worse, then steps to that successor if it is
        strictly better. The treap must be non-empty.
        """
        cur = self.root
        if not cur[6] and not cur[7]:
            # Single node: nothing to compare against.
            return cur[0]
        # Last node seen with f(next) <= f(node); defaults to the minimum node.
        left = self.left
        while True:
            nxt = cur[7]
            if not nxt:
                # The maximum node has no successor: treat as "go left".
                if not cur[3]:
                    break
                cur = cur[3]
                continue
            val = f(nxt[0]) - f(cur[0])
            if val <= 0:
                left = cur
                if not cur[4]:
                    break
                cur = cur[4]
            else:
                if not cur[3]:
                    break
                cur = cur[3]
        if left[7] and f(left[0]) > f(left[7][0]):
            left = left[7]
        return left[0]

    @staticmethod
    def prev(node):
        """Return the in-order predecessor of ``node`` (None if minimum)."""
        return node[6]

    @staticmethod
    def next(node):
        """Return the in-order successor of ``node`` (None if maximum)."""
        return node[7]

    def remove(self, x):
        """Remove key ``x``. Return True if it was stored, False otherwise."""
        cur = self.find_node(x)
        if cur is None or cur[0] != x:
            # Empty treap, or x is not stored.
            return False
        # Unlink from the in-order doubly linked list.
        if cur[6]:
            cur[6][7] = cur[7]
        if cur[7]:
            cur[7][6] = cur[6]
        if self.left is cur:
            self.left = cur[7]
        if self.right is cur:
            self.right = cur[6]
        # Rotate cur down until it is a leaf. Each step lifts the child with
        # the higher priority, which preserves the heap property.
        # ``parent[t]`` is always the link that currently points at cur.
        parent = cur[5]
        t = 3 if parent and x < parent[0] else 4
        while True:
            if cur[3]:
                if cur[4]:
                    if cur[3][1] < cur[4][1]:
                        node = self.left_rotate(cur)
                        p = 3
                    else:
                        node = self.right_rotate(cur)
                        p = 4
                else:
                    node = self.right_rotate(cur)
                    p = 4
            else:
                if cur[4]:
                    node = self.left_rotate(cur)
                    p = 3
                else:
                    break
            if parent:
                parent[t] = node
            else:
                self.root = node
            parent = node; t = p
        # cur is now a leaf: detach it.
        if parent:
            parent[t] = None
        else:
            # cur was the root and a leaf, i.e. the only node.
            self.root = self.left = self.right = None
        del self.ref[x]
        return True

    def insert(self, x):
        """Insert key ``x`` and return its node.

        If ``x`` is already stored, its existing node is returned unchanged.
        """
        if x in self.ref:
            return self.ref[x]
        prob = random.random()
        new_node = self.__make(x, prob)
        self.ref[x] = new_node
        if not self.root:
            self.root = self.left = self.right = new_node
            # Return the node here as well, consistent with every other path.
            return new_node
        # Plain BST descent to a leaf position. ``prv`` / ``nxt`` end up as the
        # last nodes we went right / left from, i.e. the in-order neighbours.
        cur = self.root
        prv = nxt = None
        while True:
            if x < cur[0]:
                nxt = cur
                if not cur[3]:
                    cur[3] = new_node
                    new_node[5] = cur
                    break
                cur = cur[3]
            else:
                prv = cur
                if not cur[4]:
                    cur[4] = new_node
                    new_node[5] = cur
                    break
                cur = cur[4]
        # Splice into the in-order linked list.
        new_node[6] = prv; new_node[7] = nxt
        if prv:
            prv[7] = new_node
        if nxt:
            nxt[6] = new_node
        if (not self.left) or (x, prob) < (self.left[0], self.left[1]):
            self.left = new_node
        if (not self.right) or (self.right[0], -self.right[1]) < (x, -prob):
            self.right = new_node
        # Rotate the new node up while its priority exceeds its parent's.
        # Each iteration first re-points cur's child link at new_node; after a
        # rotation, the grandparent's link still refers to the old parent.
        while cur:
            if x < cur[0]:
                cur[3] = new_node
                if prob <= cur[1]: break
                cur = self.right_rotate(cur)[5]
            else:
                cur[4] = new_node
                if prob <= cur[1]: break
                cur = self.left_rotate(cur)[5]
        if not cur:
            self.root = new_node
        return new_node

    @staticmethod
    def right_rotate(node):
        """Rotate ``node`` right and return its former left child.

        Parent pointers are updated. The caller must re-point the
        grandparent's child link (or the root) at the returned node.
        """
        # assert node[3]
        left_node = node[3]
        B = left_node[4]
        left_node[4] = node; left_node[5] = node[5]
        node[3] = B; node[5] = left_node
        if B: B[5] = node
        return left_node

    @staticmethod
    def left_rotate(node):
        """Rotate ``node`` left and return its former right child.

        Parent pointers are updated. The caller must re-point the
        grandparent's child link (or the root) at the returned node.
        """
        # assert node[4]
        right_node = node[4]
        B = right_node[3]
        right_node[3] = node; right_node[5] = node[5]
        node[4] = B; node[5] = right_node
        if B: B[5] = node
        return right_node

    def debug(self):
        """Print the tree sideways, then the prev/next chains from the root.

        The keys are formatted with ``%d``.
        """
        tmp = []
        def dfs(node, level):
            if node[3]:
                dfs(node[3], level+1)
            tmp.append(" "*level + "(%d, %f)" % (node[0], node[1]))
            if node[4]:
                dfs(node[4], level+1)
        tmp.append("---")
        dfs(self.root, 0)
        tmp.append("---")
        print(*tmp, sep='\n')
        left = []; right = []
        node = self.root
        while node:
            left.append(node[0])
            node = node[6]
        node = self.root
        while node:
            right.append(node[0])
            node = node[7]
        print(*left, sep='<-')
        print(*right, sep='->')
# Convex Hull Trick: O((N + Q)logN) with Treap
class ConvexHullTrick:
    """Dynamic lower envelope of lines ``y = a*x + b`` for minimum queries.

    Only lines that attain the minimum for some x are kept. Their slopes are
    stored as keys of a Treap (ascending order), and ``self.B`` maps each
    kept slope to its intercept. Overall cost is O((N + Q) log N) expected
    for N additions and Q queries.
    """

    def __init__(self):
        self.A = Treap()  # slopes of the lines currently on the envelope
        self.B = {}  # slope -> intercept for every line on the envelope

    def check(self, a1, a2, a3):
        """Return True if the middle line (slope ``a2``) is redundant.

        Expects slopes ordered ``a1 > a2 > a3``, all present in ``self.B``.
        Line 2 is redundant when the intersection of lines 1 and 2 is not to
        the left of the intersection of lines 2 and 3. The test is done by
        cross-multiplication, so integer inputs stay exact.
        """
        B = self.B
        b1 = B[a1]; b2 = B[a2]; b3 = B[a3]
        return (a2-a1)*(b3-b2) >= (b2-b1)*(a3-a2)

    def add(self, a, b):
        """Add the line ``y = a*x + b``.

        For a repeated slope only the smaller intercept is kept. A new line
        that is nowhere minimal is discarded. Neighbouring lines that the
        new line makes redundant are removed.
        """
        B = self.B; A = self.A
        if a in B:
            if B[a] <= b:
                # The existing line is at least as low everywhere; the
                # envelope is already consistent, so nothing changes.
                return
            B[a] = b
            cur = A.find_node(a)
        else:
            B[a] = b
            # Successor of a in slope order (or the max slope if a is largest).
            node = A.find_node(a)
            if node and node[6]:
                a1 = node[0]; a3 = node[6][0]
                # The new line can only be redundant when its slope lies
                # between two stored slopes; extreme slopes are always kept.
                if a1 >= a >= a3 and self.check(a1, a, a3):
                    del B[a]
                    return
            cur = A.insert(a)
        # Prune now-redundant neighbours on the smaller-slope side ...
        if cur and cur[6]:
            p = cur[6]
            while p and p[6] and p[7] and self.check(p[7][0], p[0], p[6][0]):
                v = p[0]; p = p[6]; A.remove(v)
                del B[v]
        # ... and on the larger-slope side.
        if cur and cur[7]:
            p = cur[7]
            while p and p[6] and p[7] and self.check(p[7][0], p[0], p[6][0]):
                v = p[0]; p = p[7]; A.remove(v)
                del B[v]

    def empty(self):
        """Return True if no line has been added."""
        return not self.B

    def get(self, x):
        """Return the minimum of ``a*x + b`` over all added lines.

        Raises:
            ValueError: if no line has been added.
        """
        B = self.B
        if not B:
            raise ValueError("ConvexHullTrick.get() called with no lines")
        calc = lambda a: a*x + B[a]
        if len(B) == 1:
            a, = B
            return calc(a)
        # Values of envelope lines at fixed x are unimodal in slope order.
        a = self.A.find_min(calc)
        return calc(a)

Verified

  • CSAcademy: "Squared Ends": source (Python3, TLE)