Global Minimum Cut (Stoer-Wagner Algorithm)

概要

重み付き無向グラフ \(G = (V, E)\) において、頂点集合 \(V\) を空でない2つの集合 \(S, T\) に分割するカット \((S, T)\) のうち、\(S\) と \(T\) の間を結ぶ辺の重みの総和が最小となるカット（全域最小カット）を求める。

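Stoer-Wagner Algorithm では、グラフ上のある2頂点 \(s, t\) に対する最小 \(s\)-\(t\) カットを1つ求め、その後 \(s\) と \(t\) を1つの頂点に縮約する、という操作を頂点が1つになるまで繰り返す。全域最小カットは \(s\) と \(t\) を分離する（このとき最小 \(s\)-\(t\) カットと一致する）か、分離しない（このとき縮約後のグラフにもそのまま残る）かのいずれかである。したがって、求めた最小 \(s\)-\(t\) カットのうち重みの総和が最小のものが全域最小カットとなる。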

フィボナッチヒープ等を利用することで、計算量 \(O(|V|(|E| + |V| \log |V|))\) で求められる。

隣接行列を用いた実装

計算量 \(O(|V|^3)\)

def global_minimum_cut(N, E, u0):
    """Return the weight of a global minimum cut using Stoer-Wagner in O(N^3).

    Args:
        N: Number of vertices, labelled 0 .. N-1.
        E: N x N symmetric adjacency matrix. E[a][b] is the (non-negative)
            weight of edge a-b, or -1 if there is no edge. Diagonal entries
            (self-loops) are ignored because they never cross a cut. The
            matrix is copied before contraction, so the caller's E is left
            unchanged.
        u0: Vertex used as the starting point of every minimum-cut phase.
            It is never contracted into another vertex.

    Returns:
        The minimum, over all partitions of the vertices into two non-empty
        sets, of the total weight of edges crossing the partition.
        Returns 10**18 when N < 2, since no such partition exists.
    """
    # Contract a private copy: the merge step below rewrites rows/columns.
    E = [list(row) for row in E]
    for i in range(N):
        E[i][i] = -1  # a self-loop never crosses a cut
    res = 10**18

    # NOTE: to recover the cut itself (not only its weight), keep
    # groups = [{i} for i in range(N)], report (groups[v], V - groups[v])
    # whenever res improves, and fold groups[v] into groups[u] on merge.

    merged = [0] * N
    for s in range(N - 1):
        # --- minimum cut phase ---
        # Grow the set `used` from u0, always adding the remaining vertex
        # whose total edge weight into the set (costs) is largest.
        used = [0] * N
        used[u0] = 1
        costs = [0] * N
        for v in range(N):
            if E[u0][v] != -1:
                costs[v] = E[u0][v]
        order = []
        # After s contractions N - s vertices remain; all except u0 are added.
        for _ in range(N - 1 - s):
            v = mc = -1
            for i in range(N):
                if used[i] or merged[i]:
                    continue
                if mc < costs[i]:
                    mc = costs[i]
                    v = i

            # v: the most tightly connected vertex
            for w in range(N):
                if used[w] or E[v][w] == -1:
                    continue
                costs[w] += E[v][w]
            used[v] = 1
            order.append(v)

        # Cut of the phase: the last vertex added versus all the others.
        # Its weight is the total weight of edges incident to that vertex.
        v = order[-1]
        ws = sum(d for d in E[v] if d != -1)
        res = min(res, ws)

        if len(order) > 1:
            u = order[-2]

            # Contract v into u: add v's edges to u's, then erase v's row/column.
            merged[v] = 1
            for w in range(N):
                if w != u and E[v][w] != -1:
                    if E[u][w] != -1:
                        E[u][w] = E[w][u] = E[u][w] + E[v][w]
                    else:
                        E[u][w] = E[w][u] = E[v][w]
                E[v][w] = E[w][v] = -1
    return res

N = 8
# Example graph: weighted undirected edges (a, b, weight).
es = [
    (0, 1, 2),
    (1, 2, 3),
    (2, 3, 4),
    (0, 4, 3),
    (1, 4, 2),
    (1, 5, 2),
    (2, 6, 2),
    (3, 6, 2),
    (3, 7, 2),
    (4, 5, 3),
    (5, 6, 1),
    (6, 7, 3),
]
# Symmetric adjacency matrix; -1 marks "no edge".
E = [[-1] * N for _ in range(N)]
for a, b, w in es:
    E[a][b] = w
    E[b][a] = w
# Pass a row-by-row copy of the matrix, starting every phase from vertex 1.
print(global_minimum_cut(N, [list(row) for row in E], 1))
# => 4

二分ヒープを利用した実装

計算量 \(O(|V| |E| \log |V|)\)

from heapq import heapify
class BinaryHeap:
    """Binary min-heap of ``(id, key)`` items with decrease-key support.

    Each node is a list ``[left, right, parent, key, id]`` linked by explicit
    child/parent pointers. ``mp`` maps every id still in the heap to the node
    currently holding it, so ``decreasekey`` can locate an item directly.
    Ids must be unique and hashable; keys must be comparable with ``<``.
    Nodes only ever move up (entries are pulled up or unlinked), so the tree
    never gets deeper than right after ``build``.
    """

    def __init__(self):
        self.root = None  # node with the smallest key, or None when empty
        self.mp = {}      # id -> node currently holding that id

    # build a heap: O(N)
    def build(self, A):
        """Replace the heap contents with the ``(id, key)`` pairs of ``A``.

        Runs in O(len(A)). An empty ``A`` yields an empty heap.
        """
        # heapify orders by key first; the id only breaks ties.
        items = [(key, k) for k, key in A]
        heapify(items)
        L = len(items)
        # node: [left, right, parent, key, id]; link the implicit array tree
        # produced by heapify into explicit nodes.
        nds = [[None, None, None, key, k] for key, k in items]
        mp = self.mp = {}  # start fresh so no stale ids survive a rebuild
        for i, nd in enumerate(nds):
            if 2*i + 1 < L:
                nd[0] = nds[2*i + 1]
            if 2*i + 2 < L:
                nd[1] = nds[2*i + 2]
            if i:
                nd[2] = nds[(i - 1) >> 1]
            mp[nd[4]] = nd
        self.root = nds[0] if nds else None

    # decrease key: O(log N)
    def decreasekey(self, k, d):
        """Lower the key of item ``k`` by ``d`` and restore heap order.

        ``d`` is expected to be non-negative: a negative ``d`` would raise the
        key without sifting down and break the heap order.
        Raises KeyError if ``k`` is not in the heap.
        """
        node = self.mp[k]
        new_key = node[3] - d
        # Sift up: shift each larger parent entry down one level, then put
        # (new_key, k) into the slot left free at the top.
        while node[2] and new_key < node[2][3]:
            parent = node[2]
            node[3] = parent[3]
            node[4] = parent[4]
            self.mp[parent[4]] = node
            node = parent
        node[3] = new_key
        node[4] = k
        self.mp[k] = node

    # pop an item: O(log N)
    def pop(self):
        """Remove and return ``(id, key)`` of the item with the smallest key.

        Raises IndexError if the heap is empty.
        """
        if self.root is None:
            raise IndexError("pop from empty heap")
        cur = self.root
        v, k = cur[3], cur[4]
        # Pull the smaller child's entry up until reaching a node with at
        # most one child; that node is then unlinked from the tree.
        while cur[0] and cur[1]:
            nxt = cur[0] if cur[0][3] < cur[1][3] else cur[1]
            cur[3:5] = nxt[3:5]
            self.mp[cur[4]] = cur
            cur = nxt
        # Splice cur's only child (if any) into cur's place.
        nxt = cur[0] or cur[1]
        if self.root is cur:
            self.root = nxt
            if nxt:
                nxt[2] = None
        else:
            prt = cur[2]
            if prt[0] is cur:
                prt[0] = nxt
            else:
                prt[1] = nxt
            if nxt:
                nxt[2] = prt
        del self.mp[k]
        return k, v

    def empty(self):
        """Return True if the heap holds no items."""
        return self.root is None

def global_minimum_cut(N, G0, u0):
    """Return the weight of a global minimum cut (Stoer-Wagner, binary heap).

    Runs in O(|V| |E| log |V|).

    Args:
        N: Number of vertices, labelled 0 .. N-1.
        G0: Adjacency list; G0[v] is an iterable of (w, d) pairs meaning an
            edge v-w of (non-negative) weight d. Every undirected edge must be
            listed from both endpoints. Parallel edges are summed and
            self-loops are ignored. G0 itself is not modified.
        u0: Vertex used as the starting point of every minimum-cut phase.
            It is never contracted into another vertex.

    Returns:
        The minimum total weight of edges crossing a partition of the
        vertices into two non-empty sets, or 10**18 when N < 2.
    """
    res = 10**18

    # Private weighted adjacency maps: G[v][w] = total weight of edges v-w.
    G = [{} for _ in range(N)]
    for v in range(N):
        adj = G[v]
        for w, d in G0[v]:
            if w != v:  # a self-loop never crosses a cut
                adj[w] = adj.get(w, 0) + d

    merged = [0] * N
    while True:
        # Heap keys are negated connection weights to the growing set, so
        # the min-heap pops the most tightly connected vertex first.
        data = [(v, -G[u0].get(v, 0))
                for v in range(N) if not merged[v] and v != u0]
        if not data:
            # Only u0 exists (N == 1): there is no cut to evaluate.
            break

        # --- minimum cut phase ---
        used = [0] * N
        used[u0] = 1
        heap = BinaryHeap()
        heap.build(data)

        order = []
        while not heap.empty():
            v, _ = heap.pop()
            for w, d in G[v].items():
                if used[w]:
                    continue
                heap.decreasekey(w, d)
            used[v] = 1
            order.append(v)

        # Cut of the phase: the last vertex added versus all the others.
        v = order[-1]
        ws = sum(G[v].values())
        res = min(res, ws)

        if len(order) == 1:
            break

        u = order[-2]

        # Contract v into u: move v's edges onto u, then drop v entirely.
        merged[v] = 1
        for w, d in G[v].items():
            if w != u:
                if w in G[u]:
                    G[u][w] += d
                    G[w][u] += d
                else:
                    G[u][w] = G[w][u] = d
            del G[w][v]
        G[v] = None
    return res

N = 8
# Example graph: weighted undirected edges (a, b, weight).
es = [
    (0, 1, 2),
    (1, 2, 3),
    (2, 3, 4),
    (0, 4, 3),
    (1, 4, 2),
    (1, 5, 2),
    (2, 6, 2),
    (3, 6, 2),
    (3, 7, 2),
    (4, 5, 3),
    (5, 6, 1),
    (6, 7, 3),
]
# Adjacency list: each undirected edge is recorded from both endpoints.
G = [[] for _ in range(N)]
for a, b, w in es:
    G[a] += [(b, w)]
    G[b] += [(a, w)]
print(global_minimum_cut(N, G, 1))
# => 4

参考