binary search

概要

二分探索まわりのメモ

bisectモジュール

import bisect

A = [1, 1, 1, 3, 3, 3, 3, 5, 5, 5]

# x 以下の最大の要素位置
def index_le(A, x):
    return bisect.bisect_right(A, x)-1
print([index_le(A, x) for x in range(0, 7)])
# => "[-1, 2, 2, 6, 6, 9, 9]"

# x 以上の最小の要素位置
def index_ge(A, x):
    return bisect.bisect_left(A, x)
print([index_ge(A, x) for x in range(0, 7)])
# => "[0, 0, 3, 3, 7, 7, 10]"

二分探索の実装

整数値に対する二分探索

A = [1, 1, 1, 3, 3, 3, 3, 5, 5, 5]

# x 以下の最大の要素位置
def index_le(A, x):
    def check(y):
        return A[y] <= x
    left = -1; right = len(A)
    while left+1 < right:
        mid = (left + right) >> 1
        if check(mid):
            # e <= left は check(e) を満たす
            left = mid
        else:
            # right <= e は check(e) を満たさない
            right = mid
    return left
print([index_le(A, x) for x in range(0, 7)])
# => "[-1, 2, 2, 6, 6, 9, 9]"

# x 以上の最小の要素位置
def index_ge(A, x):
    def check(y):
        return x <= A[y]
    left = -1; right = len(A)
    while left+1 < right:
        mid = (left + right) >> 1
        if check(mid):
            # right <= e は check(e) を満たす
            right = mid
        else:
            # e <= left は check(e) を満たさない
            left = mid
    return right
print([index_ge(A, x) for x in range(0, 7)])
# => "[0, 0, 3, 3, 7, 7, 10]"

実数値に対する二分探索

巨大な値を扱う場合は精度より停止しなくなる可能性があるため Decimal 等の利用を考える。

def sqrt(x):
    def check(y):
        return y*y < x
    EPS = 1e-9
    left = 0; right = max(x, 1)
    while right - left > EPS:
        mid = (left + right) / 2
        if check(mid):
            left = mid
        else:
            right = mid
    return left

print(sqrt(0.5))
# => "0.7071067802608013"
print(sqrt(5))
# => "2.236067976919003"
# print(sqrt(5**21)) => 停止しない


from decimal import Decimal
def sqrt_d(x):
    def check(y):
        return y*y < x
    EPS = 1e-9
    left = Decimal(0); right = Decimal(max(x, 1))
    while right - left > EPS:
        mid = (left + right) / 2
        if check(mid):
            left = mid
        else:
            right = mid
    return left
print(sqrt_d(5**21))
# => "21836601.34277138317515660994"