数据结构精品课 已更新到 V2.1, 手把手刷二叉树系列课程 上线。
LeetCode | 力扣 | 难度 |
---|---|---|
1135. Connecting Cities With Minimum Cost🔒 | 1135. 最低成本联通所有城市🔒 | 🟠 |
1584. Min Cost to Connect All Points | 1584. 连接所有点的最小费用 | 🟠 |
———–
本文是第 7 篇图论算法文章,先列举一下我之前写过的图论算法:
1、 图论算法基础
2、 二分图判定算法
3、 环检测和拓扑排序算法
像图论算法这种高级算法虽然不算难,但是阅读量普遍比较低,我本来是不想写 Prim 算法的,但考虑到算法知识结构的完整性,我还是想把 Prim 算法的坑填上,这样所有经典的图论算法就基本完善了。
Prim 算法和 Kruskal 算法都是经典的最小生成树算法,阅读本文之前,希望你读过前文 Kruskal 最小生成树算法,了解最小生成树的基本定义以及 Kruskal 算法的基本原理,这样就能很容易理解 Prim 算法逻辑了。
图论的最小生成树问题,就是让你从图中找若干边形成一个边的集合 mst
,这些边有以下特性:
1、这些边组成的是一棵树(树和图的区别在于不能包含环)。
2、这些边形成的树要包含所有节点。
3、这些边的权重之和要尽可能小。
那么 Kruskal 算法是使用什么逻辑满足上述条件,计算最小生成树的呢?
首先,Kruskal 算法用到了贪心思想,来满足权重之和尽可能小的问题:
先对所有边按照权重从小到大排序,从权重最小的边开始,选择合适的边加入 mst
集合,这样挑出来的边组成的树就是权重和最小的。
其次,Kruskal 算法用到了 Union-Find 并查集算法,来保证挑选出来的这些边组成的一定是一棵「树」,而不会包含环或者形成一片「森林」:
如果一条边的两个节点已经是连通的,则这条边会使树中出现环;如果最后的连通分量总数大于 1,则说明形成的是「森林」而不是一棵「树」。
那么,本文的主角 Prim 算法是使用什么逻辑来计算最小生成树的呢?
首先,Prim 算法也使用贪心思想来让生成树的权重尽可能小,也就是「切分定理」,这个后文会详细解释。
其次,Prim 算法使用
BFS 算法思想 和 visited
布尔数组避免成环,来保证选出来的边最终形成的一定是一棵树。
Prim 算法不需要事先对所有边排序,而是利用优先级队列动态实现排序的效果,所以我觉得 Prim 算法类似于 Kruskal 的动态过程。
下面介绍一下 Prim 算法的核心原理:切分定理。
「切分」这个术语其实很好理解,就是将一幅图分为两个不重叠且非空的节点集合:
红色的这一刀把图中的节点分成了两个集合,就是一种「切分」,其中被红线切中的的边(标记为蓝色)叫做「横切边」。
记住这两个专业术语的意思,后面我们会频繁使用这两个词,别搞混了。
接下来我们引入「切分定理」:
对于任意一种「切分」,其中权重最小的那条「横切边」一定是构成最小生成树的一条边。
这应该很容易证明,如果一幅加权无向图存在最小生成树,假设下图中用绿色标出来的边就是最小生成树:
那么,你肯定可以找到若干「切分」方式,将这棵最小生成树切成两棵子树。比如下面这种切分:
你会发现,任选一条蓝色的「横切边」都可以将这两棵子树连接起来,构成一棵生成树。
那么为了让最终这棵生成树的权重和最小,你说你要怎么选?
肯定选权重最小的那条「横切边」对吧,这就证明了切分定理。
关于切分定理,你也可以用反证法证明:
给定一幅图的最小生成树,那么随便给一种「切分」,一定至少有一条「横切边」属于最小生成树。
假设这条「横切边」不是权重最小的,那说明最小生成树的权重和就还有再减小的余地,那这就矛盾了,最小生成树的权重和本来就是最小的,怎么再减?所以切分定理是正确的。
有了这个切分定理,你大概就有了一个计算最小生成树的算法思路了:
既然每一次「切分」一定可以找到最小生成树中的一条边,那我就随便切呗,每次都把权重最小的「横切边」拿出来加入最小生成树,直到把构成最小生成树的所有边都切出来为止。
嗯,可以说这就是 Prim 算法的核心思路,不过具体实现起来,还是要有些技巧的。
因为你没办法让计算机理解什么叫「随便切」,所以应该设计机械化的规则和章法来调教你的算法,并尽量减少无用功。
我们思考算法问题时,如果问题的一般情况不好解决,可以从比较简单的特殊情况入手,Prim 算法就是使用的这种思路。
按照「切分」的定义,只要把图中的节点切成两个不重叠且非空的节点集合即可算作一个合法的「切分」,那么我只切出来一个节点,是不是也算是一个合法的「切分」?
是的,这是最简单的「切分」,而且「横切边」也很好确定,就是这个节点的边。
那我们就随便选一个点,假设就从 A
点开始切分:
既然这是一个合法的「切分」,那么按照切分定理,这些「横切边」AB, AF
中权重最小的边一定是最小生成树中的一条边:
好,现在已经找到最小生成树的第一条边(边 AB
),然后呢,如何安排下一次「切分」?
按照 Prim 算法的逻辑,我们接下来可以围绕 A
和 B
这两个节点做切分:
然后又可以从这个切分产生的横切边(图中蓝色的边)中找出权重最小的一条边,也就又找到了最小生成树中的第二条边 BC
:
接下来呢?也是类似的,再围绕着 A, B, C
这三个点做切分,产生的横切边中权重最小的边是 BD
,那么 BD
就是最小生成树的第三条边:
接下来再围绕 A, B, C, D
这四个点做切分……
Prim 算法的逻辑就是这样,每次切分都能找到最小生成树的一条边,然后又可以进行新一轮切分,直到找到最小生成树的所有边为止。
这样设计算法有一个好处,就是比较容易确定每次新的「切分」所产生的「横切边」。
比如回顾刚才的图,当我知道了节点 A, B
的所有「横切边」(不妨表示为 cut({A, B})
),也就是图中蓝色的边:
是否可以快速算出 cut({A, B, C})
,也就是节点 A, B, C
的所有「横切边」有哪些?
是可以的,因为我们发现:
cut({A, B, C}) = cut({A, B}) + cut({C})
而 cut({C})
就是节点 C
的所有邻边:
这个特点使我们用我们写代码实现「切分」和处理「横切边」成为可能:
在进行切分的过程中,我们只要不断把新节点的邻边加入横切边集合,就可以得到新的切分的所有横切边。
当然,细心的读者肯定发现了,cut({A, B})
的横切边和 cut({C})
的横切边中 BC
边重复了。
不过这很好处理,用一个布尔数组 inMST
辅助,防止重复计算横切边就行了。
最后一个问题,我们求横切边的目的是找权重最小的横切边,怎么做到呢?
很简单,用一个优先级队列存储这些横切边,就可以动态计算权重最小的横切边了。
明白了上述算法原理,下面来看一下 Prim 算法的代码实现:
class Prim {
// 核心数据结构,存储「横切边」的优先级队列
private PriorityQueue<int[]> pq;
// 类似 visited 数组的作用,记录哪些节点已经成为最小生成树的一部分
private boolean[] inMST;
// 记录最小生成树的权重和
private int weightSum = 0;
// graph 是用邻接表表示的一幅图,
// graph[s] 记录节点 s 所有相邻的边,
// 三元组 int[]{from, to, weight} 表示一条边
private List<int[]>[] graph;
public Prim(List<int[]>[] graph) {
this.graph = graph;
this.pq = new PriorityQueue<>((a, b) -> {
// 按照边的权重从小到大排序
return a[2] - b[2];
});
// 图中有 n 个节点
int n = graph.length;
this.inMST = new boolean[n];
// 随便从一个点开始切分都可以,我们不妨从节点 0 开始
inMST[0] = true;
cut(0);/**<extend down -100><img src="/algo/images/prim/4.jpeg"> */
// 不断进行切分,向最小生成树中添加边
while (!pq.isEmpty()) {
int[] edge = pq.poll();/**<extend down -100><img src="/algo/images/prim/5.jpeg"> */
int to = edge[1];
int weight = edge[2];
if (inMST[to]) {
// 节点 to 已经在最小生成树中,跳过
// 否则这条边会产生环
continue;
}
// 将边 edge 加入最小生成树
weightSum += weight;
inMST[to] = true;
// 节点 to 加入后,进行新一轮切分,会产生更多横切边
cut(to);/**<extend up -150><img src="/algo/images/prim/9.jpeg"> */
}
}
// 将 s 的横切边加入优先队列
private void cut(int s) {
// 遍历 s 的邻边
for (int[] edge : graph[s]) {
int to = edge[1];
if (inMST[to]) {
// 相邻接点 to 已经在最小生成树中,跳过
// 否则这条边会产生环
continue;
}
// 加入横切边队列
pq.offer(edge);
}
}
// 最小生成树的权重和
public int weightSum() {
return weightSum;
}
// 判断最小生成树是否包含图中的所有节点
public boolean allConnected() {
for (int i = 0; i < inMST.length; i++) {
if (!inMST[i]) {
return false;
}
}
return true;
}
}
// 注意:cpp 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
class Prim {
private:
// 核心数据结构,存储「横切边」的优先级队列
// 三元组 {from, to, weight} 表示一条边
priority_queue<vector<int>, vector<vector<int>>, greater<vector<int>>> pq;
// 类似 visited 数组的作用,记录哪些节点已经成为最小生成树的一部分
vector<bool> inMST;
// 记录最小生成树的权重和
int weightSum = 0;
// graph 是用邻接表表示的一幅图,
// graph[s] 记录节点 s 所有相邻的边
vector<vector<int>>* graph;
public:
Prim(vector<vector<int>>* graph) {
this->graph = graph;
// 图中有 n 个节点
int n = graph->size();
this->inMST.resize(n);
// 随便从一个点开始切分都可以,我们不妨从节点 0 开始
inMST[0] = true;
cut(0);
// 不断进行切分,向最小生成树中添加边
while(!pq.empty()) {
vector<int> edge = pq.top();
pq.pop();
int to = edge[1];
int weight = edge[2];
if (inMST[to]) {
// 节点 to 已经在最小生成树中,跳过
// 否则这条边会产生环
continue;
}
// 将边 edge 加入最小生成树
weightSum += weight;
inMST[to] = true;
// 节点 to 加入后,进行新一轮切分,会产生更多横切边
cut(to);
}
}
// 将 s 的横切边加入优先队列
void cut(int s) {
// 遍历 s 的邻边
for (vector<int>& edge : (*graph)[s]) {
int to = edge[1];
if (inMST[to]) {
// 相邻接点 to 已经在最小生成树中,跳过
// 否则这条边会产生环
continue;
}
// 加入横切边队列
pq.push(edge);
}
}
// 最小生成树的权重和
int weightSum() {
return weightSum;
}
// 判断最小生成树是否包含图中的所有节点
bool allConnected() {
for (bool connected : inMST) {
if (!connected) {
return false;
}
}
return true;
}
};
# 注意:python 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
# 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
import heapq
class Prim:
# 核心数据结构,存储「横切边」的优先级队列
def __init__(self, graph: List[List[int]]):
self.graph = graph
self.pq = [] # PriorityQueue<int[]> 的实现
self.inMST = [False] * len(graph) # 类似 visited 数组的作用,记录哪些节点已经成为最小生成树的一部分
self.weightSum = 0 # 记录最小生成树的权重和
self.inMST[0] = True # 随便从一个点开始切分都可以,我们不妨从节点 0 开始
self.cut(0)
# 不断进行切分,向最小生成树中添加边
while self.pq:
# 按照边的权重从小到大排序
edge = heapq.heappop(self.pq)
to = edge[1] # 表示相邻节点
weight = edge[2] # 表示这条边的权重
if self.inMST[to]: # 节点 to 已经在最小生成树中,跳过。否则这条边会产生环
continue
self.weightSum += weight # 将边 edge 加入最小生成树
self.inMST[to] = True
self.cut(to) # 节点 to 加入后,进行新一轮切分,会产生更多横切边
# 将 s 的横切边加入优先队列
def cut(self, s):
for edge in self.graph[s]: # 遍历 s 的邻边
to = edge[1] # 相邻的节点
if self.inMST[to]: # 相邻接点 to 已经在最小生成树中,跳过
continue
heapq.heappush(self.pq, edge) # 加入横切边队列
# 最小生成树的权重和
def weightSum(self) -> int:
return self.weightSum
# 判断最小生成树是否包含图中的所有节点
def allConnected(self) -> bool:
for i in range(len(self.inMST)):
if not self.inMST[i]:
return False
return True
// 注意:go 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
type Prim struct {
// 核心数据结构,存储「横切边」的优先级队列
pq *priorityQueue
// 类似 visited 数组的作用,记录哪些节点已经成为最小生成树的一部分
inMST []bool
// 记录最小生成树的权重和
weightSum int
// graph 是用邻接表表示的一幅图,
// graph[s] 记录节点 s 所有相邻的边,
// 三元组 int[]{from, to, weight} 表示一条边
graph [][]int
}
func newPrim(graph [][]int) *Prim {
pr := &Prim{
graph: graph,
pq: newPriorityQueue(),
}
pq := pr.pq
// 图中有 n 个节点
n := len(graph)
pr.inMST = make([]bool, n)
// 随便从一个点开始切分都可以,我们不妨从节点 0 开始
pr.inMST[0] = true
pr.cut(0)
// 不断进行切分,向最小生成树中添加边
for !pq.isEmpty() {
edge := pq.poll()
to := edge[1]
weight := edge[2]
if pr.inMST[to] {
// 节点 to 已经在最小生成树中,跳过
// 否则这条边会产生环
continue
}
// 将边 edge 加入最小生成树
pr.weightSum += weight
pr.inMST[to] = true
// 节点 to 加入后,进行新一轮切分,会产生更多横切边
pr.cut(to)
}
return pr
}
// 将 s 的横切边加入优先队列
func (pr *Prim) cut(s int) {
pq := pr.pq
// 遍历 s 的邻边
for _, edge := range pr.graph[s] {
to := edge[1]
if pr.inMST[to] {
// 相邻接点 to 已经在最小生成树中,跳过
// 否则这条边会产生环
continue
}
// 加入横切边队列
pq.offer(edge)
}
}
// 最小生成树的权重和
func (pr *Prim) weightSumm() int {
return pr.weightSum
}
// 判断最小生成树是否包含图中的所有节点
func (pr *Prim) allConnected() bool {
for i := 0; i < len(pr.inMST); i++ {
if !pr.inMST[i] {
return false
}
}
return true
}
// 注意:javascript 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
/**
* 核心数据结构,存储「横切边」的优先级队列
* 类似 visited 数组的作用,记录哪些节点已经成为最小生成树的一部分
* 记录最小生成树的权重和
* graph 是用邻接表表示的一幅图,
* graph[s] 记录节点 s 所有相邻的边,
* 三元组 int[]{from, to, weight} 表示一条边
*/
var Prim = function(graph) {
this.graph = graph;
// 按照边的权重从小到大排序
this.pq = new PriorityQueue(function(a, b) {
return a[2] - b[2];
});
// 图中有 n 个节点
var n = graph.length;
this.inMST = new Array(n).fill(false);
// 随便从一个点开始切分都可以,我们不妨从节点 0 开始
this.inMST[0] = true;
this.cut(0);
// 不断进行切分,向最小生成树中添加边
while (!this.pq.isEmpty()) {
var edge = this.pq.poll();
var to = edge[1];
var weight = edge[2];
if (this.inMST[to]) {
// 节点 to 已经在最小生成树中,跳过
// 否则这条边会产生环
continue;
}
// 将边 edge 加入最小生成树
this.weightSum += weight;
this.inMST[to] = true;
// 节点 to 加入后,进行新一轮切分,会产生更多横切边
this.cut(to);
}
};
/**
* 将 s 的横切边加入优先队列
*/
Prim.prototype.cut = function(s) {
// 遍历 s 的邻边
for (var i = 0; i < this.graph[s].length; i++) {
var edge = this.graph[s][i];
var to = edge[1];
if (this.inMST[to]) {
// 相邻接点 to 已经在最小生成树中,跳过
// 否则这条边会产生环
continue;
}
// 加入横切边队列
this.pq.offer(edge);
}
};
/**
* 最小生成树的权重和
*/
Prim.prototype.weightSum = function() {
return this.weightSum;
};
/**
* 判断最小生成树是否包含图中的所有节点
*/
Prim.prototype.allConnected = function() {
for (var i = 0; i < this.inMST.length; i++) {
if (!this.inMST[i]) {
return false;
}
}
return true;
};
明白了切分定理,加上详细的代码注释,你应该能够看懂 Prim 算法的代码了。
这里我们可以再回顾一下本文开头说的 Prim 算法和 Kruskal 算法 的联系:
Kruskal 算法是在一开始的时候就把所有的边排序,然后从权重最小的边开始挑选属于最小生成树的边,组建最小生成树。
Prim 算法是从一个起点的切分(一组横切边)开始执行类似 BFS 算法的逻辑,借助切分定理和优先级队列动态排序的特性,从这个起点「生长」出一棵最小生成树。
说到这里,Prim 算法的时间复杂度是多少呢?
这个不难分析,复杂度主要在优先级队列 pq
的操作上,由于 pq
里面装的是图中的「边」,假设一幅图边的条数为 E
,那么最多操作 O(E)
次 pq
。每次操作优先级队列的时间复杂度取决于队列中的元素个数,取最坏情况就是 O(logE)
。
所以这种 Prim 算法实现的总时间复杂度是 O(ElogE)
。回想一下
Kruskal 算法,它的时间复杂度主要是给所有边按照权重排序,也是 O(ElogE)
。
不过话说回来,和后文 Dijkstra 算法 类似,Prim 算法的时间复杂度也是可以优化的,但优化点在于优先级队列的实现上,和 Prim 算法本身的算法思想关系不大,所以我们这里就不做讨论了,有兴趣的读者可以自行搜索。
接下来,我们实操一波,把之前用 Kruskal 算法解决的力扣题目运用 Prim 算法再解决一遍。
第一题是力扣第 1135 题「 最低成本联通所有城市」,这是一道标准的最小生成树问题:
函数签名如下:
int minimumCost(int n, int[][] connections);
// 注意:cpp 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
int minimumCost(int n, vector<vector<int>>& connections);
# 注意:python 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
# 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
def minimumCost(n: int, connections: List[List[int]]) -> int:
// 注意:go 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
func minimumCost(n int, connections [][]int) int {}
// 注意:javascript 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
var minimumCost = function(n, connections) {
// function body here
}
每座城市相当于图中的节点,连通城市的成本相当于边的权重,连通所有城市的最小成本即是最小生成树的权重之和。
那么解法就很明显了,我们先把题目输入的 connections
转化成邻接表形式,然后输入给之前实现的 Prim
算法类即可:
public int minimumCost(int n, int[][] connections) {
// 转化成无向图邻接表的形式
List<int[]>[] graph = buildGraph(n, connections);
// 执行 Prim 算法
Prim prim = new Prim(graph);
if (!prim.allConnected()) {
// 最小生成树无法覆盖所有节点
return -1;
}
return prim.weightSum();
}
List<int[]>[] buildGraph(int n, int[][] connections) {
// 图中共有 n 个节点
List<int[]>[] graph = new LinkedList[n];
for (int i = 0; i < n; i++) {
graph[i] = new LinkedList<>();
}
for (int[] conn : connections) {
// 题目给的节点编号是从 1 开始的,
// 但我们实现的 Prim 算法需要从 0 开始编号
int u = conn[0] - 1;
int v = conn[1] - 1;
int weight = conn[2];
// 「无向图」其实就是「双向图」
// 一条边表示为 int[]{from, to, weight}
graph[u].add(new int[]{u, v, weight});
graph[v].add(new int[]{v, u, weight});
}
return graph;
}
class Prim { /* 见上文 */ }
// 注意:cpp 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
int minimumCost(int n, vector<vector<int>>& connections) {
vector<vector<pair<int, int>>> graph(n); // 转化成无向图邻接表的形式
for (auto& connection: connections) {
int u = connection[0] - 1; // 题目给的节点编号是从 1 开始的
int v = connection[1] - 1;
int weight = connection[2];
graph[u].push_back({v, weight}); // 「无向图」其实就是「双向图」
graph[v].push_back({u, weight});
}
Prim prim(graph);
if (!prim.allConnected()) {
// 最小生成树无法覆盖所有节点
return -1;
}
return prim.weightSum();
}
class Prim {
public:
explicit Prim(vector<vector<pair<int, int>>> &graph): m_graph(graph) {
m_weight.resize(m_graph.size());
m_visited.resize(m_graph.size());
m_parent.resize(m_graph.size(), -1);
for (int i = 0; i < m_graph.size(); ++i) {
m_visited[i] = false;
m_weight[i] = INT_MAX;
}
}
int weightSum() {
return accumulate(m_weight.begin(), m_weight.end(), 0);
}
bool allConnected() {
visit(0);
for (const auto& visited : m_visited) {
if (!visited) {
return false;
}
}
return true;
}
private:
vector<vector<pair<int, int>>> &m_graph;
vector<int> m_weight;
vector<int> m_parent;
vector<bool> m_visited;
void visit(int current) {
const auto& adjacents = m_graph[current];
m_visited[current] = true;
for (const auto& adjacent : adjacents) {
const auto& [node, weight] = adjacent;
if (m_visited[node]) {
continue;
}
if (weight >= m_weight[node]) {
continue;
}
m_weight[node] = weight;
m_parent[node] = current;
visit(node);
}
}
};
# 注意:python 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
# 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
def minimumCost(n: int, connections: List[List[int]]) -> int:
def buildGraph(n, connections):
graph = [[] for _ in range(n)] # 图中共有 n 个节点
for conn in connections:
# 题目给的节点编号是从 1 开始的,
# 但我们实现的 Prim 算法需要从 0 开始编号
u = conn[0] - 1
v = conn[1] - 1
weight = conn[2]
# 「无向图」其实就是「双向图」
# 一条边表示为 [from, to, weight]
graph[u].append([u, v, weight])
graph[v].append([v, u, weight])
return graph
# 转化成无向图邻接表的形式
graph = buildGraph(n, connections)
# 执行 Prim 算法
prim = Prim(graph)
if not prim.allConnected():
# 最小生成树无法覆盖所有节点
return -1
return prim.weightSum()
class Prim: # 跟Java的实现一致,略有区别见代码注释
pass
// 注意:go 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
func minimumCost(n int, connections [][]int) int {
// 转化成无向图邻接表的形式
graph := buildGraph(n, connections)
// 执行 Prim 算法
prim := NewPrim(graph)
if !prim.allConnected() {
// 最小生成树无法覆盖所有节点
return -1
}
return prim.weightSum()
}
func buildGraph(n int, connections [][]int) [][][3]int {
// 图中共有 n 个节点
graph := make([][][3]int, n)
for i := 0; i < n; i++ {
graph[i] = make([][3]int, 0)
}
for _, conn := range connections {
// 题目给的节点编号是从 1 开始的,
// 但我们实现的 Prim 算法需要从 0 开始编号
u := conn[0] - 1
v := conn[1] - 1
weight := conn[2]
// 「无向图」其实就是「双向图」
// 一条边表示为 [from, to, weight]
graph[u] = append(graph[u], [3]int{u, v, weight})
graph[v] = append(graph[v], [3]int{v, u, weight})
}
return graph
}
type Prim struct { /* 见上文 */ }
// 注意:javascript 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
var minimumCost = function(n, connections) {
function buildGraph(n, connections) {
var graph = new Array(n);
for (var i = 0; i < n; i++) {
graph[i] = [];
}
for (var j = 0; j < connections.length; j++) {
var conn = connections[j];
var u = conn[0] - 1;
var v = conn[1] - 1;
var weight = conn[2];
graph[u].push([u, v, weight]);
graph[v].push([v, u, weight]);
}
return graph;
}
var graph = buildGraph(n, connections);
var prim = new Prim(graph);
if (!prim.allConnected()) {
return -1;
}
return prim.weightSum();
};
class Prim { /* 见上文 */ }
关于 buildGraph
函数需要注意两点:
一是题目给的节点编号是从 1 开始的,所以我们做一下索引偏移,转化成从 0 开始以便 Prim
类使用;
二是如何用邻接表表示无向加权图,前文 图论算法基础 说过「无向图」其实就可以理解为「双向图」。
这样,我们转化出来的 graph
形式就和之前的 Prim
算法类对应了,可以直接施展 Prim 算法计算最小生成树。
再来看看力扣第 1584 题「 连接所有点的最小费用」:
比如题目给的例子:
points = [[0,0],[2,2],[3,10],[5,2],[7,0]]
算法应该返回 20,按如下方式连通各点:
函数签名如下:
int minCostConnectPoints(int[][] points);
// 注意:cpp 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
int minCostConnectPoints(vector<vector<int>>& points);
# 注意:python 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
# 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
def minCostConnectPoints(points: List[List[int]]) -> int:
// 注意:go 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
func minCostConnectPoints(points [][]int) int {}
// 注意:javascript 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
var minCostConnectPoints = function(points) {};
很显然这也是一个标准的最小生成树问题:每个点就是无向加权图中的节点,边的权重就是曼哈顿距离,连接所有点的最小费用就是最小生成树的权重和。
所以我们只要把 points
数组转化成邻接表的形式,即可复用之前实现的 Prim
算法类:
public int minCostConnectPoints(int[][] points) {
int n = points.length;
List<int[]>[] graph = buildGraph(n, points);
return new Prim(graph).weightSum();
}
// 构造无向图
List<int[]>[] buildGraph(int n, int[][] points) {
List<int[]>[] graph = new LinkedList[n];
for (int i = 0; i < n; i++) {
graph[i] = new LinkedList<>();
}
// 生成所有边及权重
for (int i = 0; i < n; i++) {
for (int j = i + 1; j < n; j++) {
int xi = points[i][0], yi = points[i][1];
int xj = points[j][0], yj = points[j][1];
int weight = Math.abs(xi - xj) + Math.abs(yi - yj);
// 用 points 中的索引表示坐标点
graph[i].add(new int[]{i, j, weight});
graph[j].add(new int[]{j, i, weight});
}
}
return graph;
}
class Prim { /* 见上文 */ }
// 注意:cpp 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
int minCostConnectPoints(vector<vector<int>>& points) {
int n = points.size();
vector<vector<int>> graph[n];
for (int i = 0; i < n; i++) {
graph[i] = vector<vector<int>>();
}
for (int i = 0; i < n; i++) {
for (int j = i + 1; j < n; j++) {
int xi = points[i][0], yi = points[i][1];
int xj = points[j][0], yj = points[j][1];
int weight = abs(xi - xj) + abs(yi - yj);
graph[i].push_back(vector<int>{i, j, weight});
graph[j].push_back(vector<int>{j, i, weight});
}
}
return Prim(graph).weightSum();
}
// 构造无向图
class Prim {
public:
Prim(vector<vector<int>>* graph) {
n = graph->size();
visited = vector<bool>(n, false);
dist = vector<int>(n, INT_MAX);
for (int i = 1; i < n; i++) {
pq.push({INT_MAX, i});
}
for (int i = 0; i < n; i++) {
for (auto& edge : (*graph)[i]) {
int v = edge[1], w = edge[2];
adj[i].push_back({ v, w });
}
}
process(0);
}
int weightSum() { return weight_sum; }
private:
int n, weight_sum = 0;
vector<bool> visited;
vector<int> dist;
priority_queue<pair<int, int>, vector<pair<int, int>>, greater<>> pq;
vector<vector<pair<int, int>>> adj = vector<vector<pair<int, int>>> (n);
void process(int u) {
dist[u] = 0;
pq.push({ 0, u });
while (!pq.empty()) {
int x = pq.top().second;
pq.pop();
if (visited[x]) continue;
visited[x] = true;
weight_sum += dist[x];
for (auto& edge : adj[x]) {
int v = edge.first, w = edge.second;
if (!visited[v] && w < dist[v]) {
dist[v] = w;
pq.push({dist[v], v});
}
}
}
}
};
# 注意:python 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
# 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
def minCostConnectPoints(points: List[List[int]]) -> int:
n = len(points)
graph = buildGraph(n, points)
return Prim(graph).weightSum()
# 构造无向图
def buildGraph(n: int, points: List[List[int]]) -> List[List[int]]:
graph = [[] for _ in range(n)]
# 生成所有边及权重
for i in range(n):
for j in range(i + 1, n):
xi, yi = points[i]
xj, yj = points[j]
weight = abs(xi - xj) + abs(yi - yj)
# 用 points 中的索引表示坐标点
graph[i].append([i, j, weight])
graph[j].append([j, i, weight])
return graph
class Prim:
def __init__(self, graph: List[List[int]]):
self.graph = graph
self.vertexNum = len(self.graph)
self.visited = [False] * self.vertexNum # 记录顶点是否被访问
self.minWeights = [float('inf')] * self.vertexNum # 记录权值最小的边
self.minWeights[0] = 0 # 从任意一个点出发都可以
def weightSum(self) -> int:
res = 0
for i in range(self.vertexNum):
u = self.getMinWeightVertex()
self.visited[u] = True
res += self.minWeights[u]
# 更新未访问顶点的边
for edge in self.graph[u]:
v, weight = edge[1], edge[2]
if not self.visited[v] and weight < self.minWeights[v]:
self.minWeights[v] = weight
return res
def getMinWeightVertex(self) -> int:
minWeight, minVertex = float('inf'), -1
for i in range(self.vertexNum):
if not self.visited[i] and self.minWeights[i] < minWeight:
minWeight = self.minWeights[i]
minVertex = i
return minVertex
// 注意:go 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
func minCostConnectPoints(points [][]int) int {
n := len(points)
graph := buildGraph(n, points)
return NewPrim(graph).WeightSum()
}
// 构造无向图
func buildGraph(n int, points[][]int) [][][]int {
graph := make([][][]int, n)
for i := 0; i < n; i++ {
graph[i] = make([][]int, 0)
}
// 生成所有边及权重
for i := 0; i < n; i++ {
for j := i + 1; j < n; j++ {
xi, yi := points[i][0], points[i][1]
xj, yj := points[j][0], points[j][1]
weight := Abs(xi - xj) + Abs(yi - yj)
// 用 points 中的索引表示坐标点
graph[i] = append(graph[i], []int{i, j, weight})
graph[j] = append(graph[j], []int{j, i, weight})
}
}
return graph
}
type Prim struct {/* 见上文 */}
// 注意:javascript 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
var minCostConnectPoints = function(points) {
const n = points.length;
const graph = buildGraph(n, points);
return new Prim(graph).weightSum();
};
function buildGraph(n, points) {
const graph = new Array(n);
for (let i = 0; i < n; i++) {
graph[i] = [];
}
// 生成所有边及权重
for (let i = 0; i < n; i++) {
for (let j = i + 1; j < n; j++) {
const xi = points[i][0], yi = points[i][1];
const xj = points[j][0], yj = points[j][1];
const weight = Math.abs(xi - xj) + Math.abs(yi - yj);
// 用 points 中的索引表示坐标点
graph[i].push([i, j, weight]);
graph[j].push([j, i, weight]);
}
}
return graph;
}
class Prim { /* 见上文 */ }
这道题做了一个小的变通:每个坐标点是一个二元组,那么按理说应该用五元组表示一条带权重的边,但这样的话不便执行 Prim 算法;所以我们用 points
数组中的索引代表每个坐标点,这样就可以直接复用之前的 Prim 算法逻辑了。
到这里,Prim 算法就讲完了,整个图论算法也整的差不多了,更多精彩文章,敬请期待。
_____________
《labuladong 的算法小抄》已经出版,关注公众号查看详情;后台回复关键词「进群」可加入算法群;回复「全家桶」可下载配套 PDF 和刷题全家桶:
共同维护高质量学习环境,评论礼仪见这里,违者直接拉黑不解释