最小生成树

综述

本文主要介绍最小生成树的概念以及最小生成树常用的两种算法:Kruskal算法和Prim算法。本文的完整代码可以在我的github找到。

概念

对于一个连通加权无向图G=(V, E)和权重函数$w: E \rightarrow R$,希望找到一个无环子集$T \subseteq E$,使得T中包含V中所有的结点,且$w(T)=\sum_{(u, v) \in T}w(u, v)$的值最小。则(V, T)为最小生成树(最小权重生成树)。

定义无向图$G=(V, E)$的一个切割(S, V-S)时集合A的一个划分。如果一条边$(u, v) \in E$的一个端点位于集合S,另一个端点位于集合V-S,则称该条边横跨切割(S, V-S)。如果集合A中不存在横跨该切割的边,则称该切割尊重集合A。在横跨一个切割的所有边中,权重最小的边称为轻量级边

最小生成树的可以使用贪心策略来生成,这个贪心策略由下面的通用方法来表述。该方法在每个时刻生长最小生成树的一条边,并将其加入到边集合A中,A遵循以下循环不变式:

在每遍循环之前,A时某棵最小生成树的一个子集

在每一步,我们要训责一条边(u, v),将其加入到集合A中,且使得A仍然遵守此循环不变式。由于这样的边没有破坏A的循环不变式,称这样的边对于A是安全的。尊重集合A的切割的轻量级边对于A是安全的。最小生成树的通用方法伪代码如下:

1
2
3
4
5
GENERIC-MST(G, w)
初始化A为空集
while A没有形成一个生成树
找到A的一条安全边,加入到A中
return A

Kruskal 算法

Kruskal算法的基本思路如下:

  1. 初始化边集合A为空集
  2. 初始化每个结点为一棵单结点的树,利用不相交集合来实现
  3. 对所有的边按照升序排列
  4. 取出按照升序排列的一条边(u, v),判断u和v是否属于同一棵树,如果不同,则将u所在的树和v所在的树并起来,将边加入到集合A中。

Kruskal算法始终维护一个森林,森林包含图G所有的结点,每次找到一条连接森林中的两棵树的权重最小的边,将其添加到边集合中。我们假设这样的边为$(u, v)$, $C_1, C_2$为其连接的两棵树,则(u, v)必定为一条安全边。令$(V_{C_1}, V-V_{C_1})$为一个划分,则此划分尊重集合A,则边(u, v)为一棵轻量级边,因此(u, v)为一条安全边。

在Kruskal算法中我们需要使用不相交集合,我们用p、rank数组表示其父亲和秩。其定义如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def union(x, y, p, r):
return link(find_set(x, p), find_set(y, p), p, r)

def link(x, y, p, r):
xRoot = find_set(x, p)
yRoot = find_set(y, p)
if r[xRoot] > r[yRoot]:
p[yRoot] = xRoot
else:
p[xRoot] = yRoot
if r[yRoot] == r[xRoot]:
r[yRoot] += 1

def find_set(x, p):
if p[x] != x:
p[x] = find_set(p[x], p)

return p[x]

在本次示例中,我们使用邻接链表来表示无向图,其python代码示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
class LinkGraph:
def __init__(self, n):
self.n = n
self.adj = [[] for _ in range(n)]
self.e = []

# 相当于将LinkGraph当成数据结构使用。
def insert(g, e):
for u, v in e:
g.adj[u].append(v)
g.adj[v].append(u)
g.e.append((u, v))

Kruskal算法的python代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def kruskal(g, w):
A = []
# make set
p = list(range(g.n))
r = g.n * [0]
tmp = []
for i, (u, v) in enumerate(g.e):
tmp.append((i, w[u][v]))

tmp = sorted(tmp, key=lambda x: x[1])
for i, iw in tmp:
u, v = g.e[i]
if find_set(u, p) != find_set(v, p):
union(u, v, p, r)
A.append((u, v))

return A

Kruskal算法的时间复杂度为$O(ElogV)$。

Prim算法

Prim算法很朴素,它维护一个边集合A,A中结点集为$V_A$,它直接使用$(V_A, V - V_A)$作为一个划分,则此划分必定尊重集合A。我们只要找到横跨该切割的一条轻量级边(u, v),则(u, v)对A来说是安全。值得说明的是A中的边始终构成一棵树。Prim算法的python代码示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
def find(d, flag):
u, min_dist = -1, float('inf')
for i in range(len(d)):
if flag[i] is False and d[i] < min_dist:
u, min_dist = i, d[i]
if u == -1:
return None
return u

def prim(g, w, r):
INF = float('inf')
p = []
flag = []
d = []
for i in range(g.n):
p.append(None)
flag.append(False)
if r == i:
d.append(0)
else:
d.append(INF)

while True:
u = find(d, flag)
if u is None:
break

flag[u] = True
for v in g.adj[u]:
if flag[v] is False and w[u][v] < d[v]:
p[v] = u
d[v] = w[u][v]

return p

以上代码的基本思路如下:

  1. 初始化数组flag、p和d。flag为一个状态数组,用以表示结点是否已经在$V_A$中,若$flag[u]=False$,则结点不在$V_A$中。p用来保存结点的父结点,$p[v]=u$则v的父结点为u。d用来保存结点到$V_A$中所有结点的最短距离。我们初始化源结点r的距离为0。
  2. while循环的每一次循环,都先找一个结点u,则$(p[u], u)$为一条轻量级边,将结点加入$V_A$,即设置flag[u]=True。扫描u的所有邻接结点v,若结点v不在$V_A$中,则将其父亲设为u,即将边(u, v)加入到了A中。同时更新v到结点集$V_A$的最短距离。

分析下上面算法的时间复杂度,首先初始化时间为O(V),while循环了V次,find操作的时间复杂度为O(V)。for循环了2E次,则总时间复杂度为$O(V+V^2+2E)$,即$O(V^2)$。在算法导论中引入小根堆,使得其时间复杂度为O(ElogV)。
不过这依赖于小根堆的设计,下面的python代码是我用小根堆实现的Prim算法。其时间复杂度依赖于数组的index方法,如果index方法的时间复杂度为线性,即为O(V),则下面的算法的时间复杂度反而达不到O(ElogV),反而时间复杂度为O(EV)。如果良好的设计使得for循环中除了维护小根堆性质的rise操作外,时间复杂度均为O(1),则时间复杂度能达到O(ElogV)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
def prim(g, w, r):
INF = float('inf')
q = Heap()
p = []
for i in range(g.n):
p.append(None)
if r == i:
q.heappush(0, i)
else:
q.heappush(INF, i)

while q.arr:
_, u = q.heappop()
for v in g.adj[u]:
if v in q.indexs:
index = q.indexs.index(v)
if w[u][v] < q.arr[index]:
p[v] = u
q.arr[index] = w[u][v]
q.rise(index)
return p

# 小根堆的实现
class Heap:
def __init__(self, arr=None):
self.arr = arr if arr else []
self.indexs = list(range(len(arr))) if arr else []
if self.arr:
self.buildheap()

def heappush(self, item, index):
self.arr.append(item)
self.indexs.append(index)
i = len(self.arr) - 1
parentpos = (i - 1) // 2
while i > 0 and self.arr[parentpos] > self.arr[i]:
self.arr[parentpos], self.arr[i] = self.arr[i], self.arr[parentpos]
self.indexs[parentpos], self.indexs[i] = self.indexs[i], self.indexs[parentpos]
i = parentpos
parentpos = (i - 1) // 2

def buildheap(self):
n = len(self.arr)
for i in reversed(range(n // 2)):
self.heapify(i)

def heapify(self, i):
n = len(self.arr)
left = 2 * i + 1
right = left + 1
smallest = i
if left < n and self.arr[left] < self.arr[smallest]:
smallest = left

if right < n and self.arr[right] < self.arr[smallest]:
smallest = right

if smallest != i:
self.arr[smallest], self.arr[i] = self.arr[i], self.arr[smallest]
self.indexs[smallest], self.indexs[i] = self.indexs[i], self.indexs[smallest]
self.heapify(smallest)

def rise(self, i):
n = len(self.arr)
parentpos = (i - 1) // 2
while i > 0 and self.arr[parentpos] > self.arr[i]:
self.arr[parentpos], self.arr[i] = self.arr[i], self.arr[parentpos]
self.indexs[parentpos], self.indexs[i] = self.indexs[i], self.indexs[parentpos]
i = parentpos
parentpos = (i - 1) // 2

def heappop(self):
if len(self.arr) == 0:
return None
elif len(self.arr) == 1:
item = self.arr.pop()
index = self.indexs.pop()
else:
item = self.arr[0]
self.arr[0] = self.arr.pop()
index = self.indexs[0]
self.indexs[0] = self.indexs.pop()
self.heapify(0)

return item, index

Reference

本文主要参考《算法导论》。