概要

Rolling Hashは文字列のハッシュによって文字列比較を高速に行う。

ハッシュは衝突する可能性もある。衝突させにくくするには、複数のハッシュを計算して比較する等をすればよい。

実装

# 1-dimension Rolling Hash
class RollingHash():
    def __init__(self, s, base, mod):
        self.mod = mod
        self.pw = pw = [1]*(len(s)+1)

        l = len(s)
        self.h = h = [0]*(l+1)

        v = 0
        for i in range(l):
            h[i+1] = v = (v * base + ord(s[i])) % mod
        v = 1
        for i in range(l):
            pw[i+1] = v = v * base % mod
    def get(self, l, r):
        return (self.h[r] - self.h[l] * self.pw[r-l]) % self.mod

# 非クラス版
base = 37; mod = 10**9 + 9
pw = None
def rolling_hash(s):
    l = len(s)
    h = [0]*(l + 1)
    v = 0
    for i in range(l):
        h[i+1] = v = (v * base + ord(s[i])) % mod
    return h
# RH前に、必要な長さの最大値分のpow-tableを計算しておく
def setup_pw(l):
    global pw
    pw = [1]*(l + 1)
    v = 1
    for i in range(l):
        pw[i+1] = v = v * base % mod
def get(h, l, r):
    return (h[r] - h[l] * pw[r-l]) % mod

Verified

  • AOJ: "ALDS1_14_B: String Search - String Search": source (Python3, 1.88sec)

実装 (mod の乱数化)

Rolling HashはHack可能なコンテストで使うと、ハッシュが衝突する入力で落とされる可能性がある。
これに対応するため、乱数を使って狙って衝突させにくくした実装。

# Random Rolling Hash
import random, math
random.seed()
def gen(a, b, num):
    result = set()
    while 1:
        while 1:
            v = random.randint(a, b)//2*2+1
            if v not in result:
                break
        for x in range(3, int(math.sqrt(v))+1, 2):
            if v % x == 0:
                break
        else:
            result.add(v)
            if len(result) == num:
                break
    return result
class RH():
    def __init__(self, s, base, mod):
        self.base = base
        self.mod = mod
        self.rev = pow(base, mod-2, mod)

        l = len(s)
        self.h = h = [0]*(l+1)
        tmp = 0
        for i in range(l):
            num = ord(s[i])
            tmp = (tmp*base + num) % mod
            h[i+1] = tmp
    def calc(self, l, r):
        return (self.h[r] - self.h[l] + self.mod) * pow(self.rev, l, self.mod) % self.mod
class RRH():
    def __init__(self, s, num=10):
        # 2 ~ 1000の間の乱数でnum個のハッシュをとる
        param = (2, 10**3, num)
        MOD = 10**9+7
        self.rhs = [RH(s, p, MOD) for p in gen(*param)]
    def calc(self, l, r):
        return [rh.calc(l, r) for rh in self.rhs]

# usage: RRH("abcd")

実装 (2次元)

Rolling Hash を二次元配列上に拡張したもの

特定の領域中の値が完全一致するかを判定できる。

mod = 10**9 + 9; p = 13; q = 19

p_table = q_table = None
def prepare(L):
    global p_table, q_table
    p_table = [1]*(L+1); q_table = [1]*(L+1)
    for i in range(L):
        p_table[i+1] = p_table[i] * p % mod
        q_table[i+1] = q_table[i] * q % mod

def rolling_hash(S, W, H):
    D = [[0]*(W+1) for i in range(H+1)]
    for i in range(H):
        su = 0
        dp = D[i]
        di = D[i+1]
        si = S[i]
        for j in range(W):
            v = si[j] # v = ord(si[j]) if si[j] is str
            su = (su*p + v) % mod
            di[j+1] = (su + dp[j+1]*q) % mod
    return D
def get(S, x0, y0, x1, y1):
    P = p_table[x1 - x0]; Q = q_table[y1 - y0]
    return (S[y1][x1] - S[y1][x0] * P - S[y0][x1] * Q + S[y0][x0] * (P * Q) % mod) % mod

# example
data1 = [
    [1, 0, 1, 0, 1],
    [1, 1, 1, 0, 1],
    [0, 1, 1, 1, 1],
]
data2 = [
    [1, 0, 1],
    [1, 0, 1],
]
prepare(L = 5)
rh1 = rolling_hash(data1, 5, 3)
rh2 = rolling_hash(data2, 3, 2)
print(get(rh1, 2, 0, 5, 2) == get(rh2, 0, 0, 3, 2))
# => "True"

Verified

  • AOJ: "ALDS1_14_C: Pattern Search": source (Python3, 1.98sec)

  • AOJ: "1312: Where’s Wally": source (Python3, 16.64sec)


戻る