概要
平衡二分探索木を用いてConvex Hull Trickを行う。
最小値をとり得る \(x\) が存在する直線 \(y = ax + b\)（すなわち下側包絡線を構成する直線）のみを傾き降順で管理し、\(x = k\) で最小値をとる直線を二分探索で求める。
Pythonの標準ライブラリには平衡二分探索木がないため、自前で実装する必要がある（また、平衡二分探索木は定数倍が重く、Pythonでは実行時間が厳しくなりやすい）。
計算量
\(O((N + Q) \log N)\)（期待計算量。\(N\): 追加する直線の数、\(Q\): 最小値クエリの数）
実装
Treap を用いた実装。
# Treap
import random
# Reseed the module-level RNG (with no argument Python uses OS randomness when
# available, otherwise the current time) so that the priorities drawn by
# Treap.insert via random.random() differ between runs.
random.seed()
class Treap:
    """Randomized balanced BST (treap) over distinct, mutually comparable keys.

    Every node is a plain list::

        [key, priority, unused, left, right, parent, prev, next]

    ``left``/``right``/``parent`` are the tree links. ``prev``/``next``
    thread all nodes into a doubly linked list in increasing key order.
    Priorities come from ``random.random()`` and are kept in max-heap order:
    a node's priority is never below its children's. Slot 2 is always 0 and
    is not read anywhere in this class.

    Attributes:
        root:  root node, or None when the treap is empty.
        left:  node holding the smallest key, or None.
        right: node holding the largest key, or None.
        ref:   dict mapping every stored key to its node.
    """

    def __init__(self):
        self.root = None
        self.left = self.right = None
        self.ref = {}

    @staticmethod
    def __make(x, prob):
        # key, priority, unused (0), left child, right child, parent,
        # prev node (key order), next node (key order)
        return [x, prob, 0, None, None, None, None, None]

    def empty(self):
        """Return True if the treap holds no keys."""
        return not self.root

    def find_node(self, x):
        """Return the node holding ``x`` or, if absent, the node of its successor.

        If ``x`` is greater than every stored key, the node with the largest
        key is returned. If the treap is empty, None is returned.
        """
        if x in self.ref:
            return self.ref[x]
        parent = cur = self.root
        while cur:
            parent = cur
            if x == cur[0]:
                return cur
            cur = cur[3] if x < cur[0] else cur[4]
        # ``parent`` is the last node on the search path, i.e. x's in-order
        # predecessor or successor. If it is the predecessor, step forward.
        if parent and parent[0] < x and parent[7]:
            parent = parent[7]
        return parent

    def find_min(self, f):
        """Return the key k minimizing ``f(k)``.

        Assumes that, walking the keys in increasing order, ``f`` stops
        decreasing at some point and never decreases again (unimodal).
        The tree is binary-searched for the last node whose successor does
        not increase ``f``. The minimum is that node or its successor.

        Raises:
            ValueError: if the treap is empty.
        """
        cur = self.root
        if cur is None:
            raise ValueError("find_min() on an empty Treap")
        if not cur[6] and not cur[7]:
            # Single node: nothing to compare.
            return cur[0]
        best = self.left
        while True:
            nxt = cur[7]
            if nxt and f(nxt[0]) - f(cur[0]) <= 0:
                # Still non-increasing here: the minimum lies further right.
                best = cur
                child = cur[4]
            else:
                child = cur[3]
            if not child:
                break
            cur = child
        if best[7] and f(best[0]) > f(best[7][0]):
            best = best[7]
        return best[0]

    @staticmethod
    def prev(node):
        """Return the node with the next-smaller key, or None."""
        return node[6]

    @staticmethod
    def next(node):
        """Return the node with the next-larger key, or None."""
        return node[7]

    def remove(self, x):
        """Remove key ``x``. Return True if it was present, False otherwise."""
        # ``ref`` holds exactly the stored keys, so a missing key (including
        # the empty-treap case) is detected without touching the tree.
        cur = self.ref.get(x)
        if cur is None:
            return False
        # Unlink from the key-ordered list and maintain the min/max pointers.
        prv, nxt = cur[6], cur[7]
        if prv:
            prv[7] = nxt
        if nxt:
            nxt[6] = prv
        if self.left is cur:
            self.left = nxt
        if self.right is cur:
            self.right = prv
        # Rotate cur down (lifting its higher-priority child) until it is a
        # leaf. Throughout, ``parent[side]`` is the slot that points at cur.
        parent = cur[5]
        side = 3 if parent and parent[3] is cur else 4
        while cur[3] or cur[4]:
            if cur[3] and (not cur[4] or cur[3][1] >= cur[4][1]):
                node = self.right_rotate(cur)
                new_side = 4
            else:
                node = self.left_rotate(cur)
                new_side = 3
            if parent:
                parent[side] = node
            else:
                self.root = node
            parent, side = node, new_side
        # cur is now a leaf: detach it.
        if parent:
            parent[side] = None
        else:
            self.root = self.left = self.right = None
        del self.ref[x]
        return True

    def insert(self, x):
        """Insert key ``x`` and return its node.

        If ``x`` is already stored, its existing node is returned unchanged.
        """
        if x in self.ref:
            return self.ref[x]
        prob = random.random()
        new_node = self.__make(x, prob)
        self.ref[x] = new_node
        if not self.root:
            self.root = self.left = self.right = new_node
            return new_node
        # Plain BST descent, recording the in-order neighbours on the way.
        cur = self.root
        prv = nxt = None
        while True:
            if x < cur[0]:
                nxt = cur
                if not cur[3]:
                    cur[3] = new_node
                    break
                cur = cur[3]
            else:
                prv = cur
                if not cur[4]:
                    cur[4] = new_node
                    break
                cur = cur[4]
        new_node[5] = cur
        new_node[6] = prv
        new_node[7] = nxt
        if prv:
            prv[7] = new_node
        if nxt:
            nxt[6] = new_node
        # Keys are unique, so a plain comparison decides the new extremes.
        if not self.left or x < self.left[0]:
            self.left = new_node
        if not self.right or self.right[0] < x:
            self.right = new_node
        # Rotate the new node up while its priority exceeds its parent's.
        # Each iteration first re-points the (grand)parent's child slot at
        # new_node, because the previous rotation left it stale.
        while cur:
            if x < cur[0]:
                cur[3] = new_node
                if prob <= cur[1]:
                    break
                cur = self.right_rotate(cur)[5]
            else:
                cur[4] = new_node
                if prob <= cur[1]:
                    break
                cur = self.left_rotate(cur)[5]
        if not cur:
            self.root = new_node
        return new_node

    @staticmethod
    def right_rotate(node):
        """Lift ``node``'s left child above it and return that child.

        The caller must update the child pointer of the old parent.
        """
        left_node = node[3]
        moved = left_node[4]
        left_node[4] = node
        left_node[5] = node[5]
        node[3] = moved
        node[5] = left_node
        if moved:
            moved[5] = node
        return left_node

    @staticmethod
    def left_rotate(node):
        """Lift ``node``'s right child above it and return that child.

        The caller must update the child pointer of the old parent.
        """
        right_node = node[4]
        moved = right_node[3]
        right_node[3] = node
        right_node[5] = node[5]
        node[4] = moved
        node[5] = right_node
        if moved:
            moved[5] = node
        return right_node

    def debug(self):
        """Print the tree sideways, then the prev/next chains through the root."""
        lines = ["---"]

        def dfs(node, level):
            if node[3]:
                dfs(node[3], level + 1)
            lines.append(" " * level + "(%s, %f)" % (node[0], node[1]))
            if node[4]:
                dfs(node[4], level + 1)

        if self.root:
            dfs(self.root, 0)
        lines.append("---")
        print(*lines, sep='\n')
        left = []
        right = []
        node = self.root
        while node:
            left.append(node[0])
            node = node[6]
        node = self.root
        while node:
            right.append(node[0])
            node = node[7]
        print(*left, sep='<-')
        print(*right, sep='->')
# Convex Hull Trick: O((N + Q)logN) with Treap
class ConvexHullTrick:
    """Lower envelope of lines y = a*x + b answering minimum queries.

    ``B`` maps each kept slope ``a`` to its intercept ``b``. ``A`` is a Treap
    holding the same slopes in sorted order. Lines that can never be the
    strict minimum are discarded as lines are added, so ``get`` can
    binary-search the remaining lines.
    """

    def __init__(self):
        self.A = Treap()  # kept slopes, in sorted order
        self.B = {}       # slope -> intercept

    def check(self, a1, a2, a3):
        """Return True if line ``a2`` is never strictly below both a1 and a3.

        Expects slopes a1 > a2 > a3 (all present in ``B``). Equivalent to
        the intersection of lines 1 and 2 lying at or right of the
        intersection of lines 2 and 3, written without division.
        """
        B = self.B
        b1, b2, b3 = B[a1], B[a2], B[a3]
        return (a2 - a1) * (b3 - b2) >= (b2 - b1) * (a3 - a2)

    def _prune(self, cur, direction):
        """Drop lines next to ``cur`` that became redundant, walking one way.

        ``direction`` is the node slot to follow: 6 walks towards smaller
        slopes, 7 towards larger slopes. Stops at the first line that is
        still needed, or at the end of the envelope.
        """
        A = self.A
        B = self.B
        p = cur[direction]
        while p and p[6] and p[7] and self.check(p[7][0], p[0], p[6][0]):
            v = p[0]
            p = p[direction]  # advance before v's node is unlinked
            A.remove(v)
            del B[v]

    def add(self, a, b):
        """Add the line y = a*x + b.

        For an already-present slope only the smaller intercept is kept. A
        new line that is dominated by its neighbours is not stored at all.
        Otherwise, neighbours it makes redundant are removed.
        """
        B = self.B
        A = self.A
        if a in B:
            if B[a] <= b:
                # The stored line is already at least as low everywhere.
                return
            B[a] = b
            cur = A.find_node(a)
        else:
            B[a] = b
            node = A.find_node(a)  # successor slope, or the largest slope
            if node and node[6]:
                a1 = node[0]
                a3 = node[6][0]
                # The range guard fails when a exceeds every stored slope
                # (node is then the maximum, not a successor).
                if a1 >= a >= a3 and self.check(a1, a, a3):
                    del B[a]
                    return
            cur = A.insert(a)
        if cur:
            self._prune(cur, 6)
            self._prune(cur, 7)

    def empty(self):
        """Return True if no line has been stored."""
        return not self.B

    def get(self, x):
        """Return the minimum of a*x + b over all stored lines.

        Raises:
            ValueError: if no line has been added.
        """
        B = self.B
        if not B:
            raise ValueError("get() on an empty ConvexHullTrick")

        def calc(a):
            return a * x + B[a]

        if len(B) == 1:
            (a,) = B
            return calc(a)
        return calc(self.A.find_min(calc))
Verified
- CSAcademy: "Squared Ends": source (Python3, TLE)