数据结构精品课 已更新到 V2.1, 手把手刷二叉树系列课程 上线。
LeetCode | 力扣 | 难度 |
---|---|---|
315. Count of Smaller Numbers After Self | 315. 计算右侧小于当前元素的个数 | 🔴 |
327. Count of Range Sum | 327. 区间和的个数 | 🔴 |
493. Reverse Pairs | 493. 翻转对 | 🔴 |
912. Sort an Array | 912. 排序数组 | 🟠 |
———–
一直都有很多读者说,想让我用 框架思维 讲一讲基本的排序算法,我觉得确实得讲讲,毕竟学习任何东西都讲求一个融会贯通,只有对其本质进行比较深刻的理解,才能运用自如。
本文就先讲归并排序,给一套代码模板,然后讲讲它在算法问题中的应用。阅读本文前我希望你读过前文 手把手刷二叉树(纲领篇)。
我在 手把手刷二叉树(第一期) 讲二叉树的时候,提了一嘴归并排序,说归并排序就是二叉树的后序遍历,当时就有很多读者留言说醍醐灌顶。
知道为什么很多读者遇到递归相关的算法就觉得烧脑吗?因为还处在「看山是山,看水是水」的阶段。
就说归并排序吧,如果给你看代码,让你脑补一下归并排序的过程,你脑子里会出现什么场景?
这是一个数组排序算法,所以你脑补一个数组的 GIF,在那一个个交换元素?如果是这样的话,那格局就低了。
但如果你脑海中浮现出的是一棵二叉树,甚至浮现出二叉树后序遍历的场景,那格局就高了,大概率掌握了我经常强调的 框架思维,用这种抽象能力学习算法就省劲多了。
那么,归并排序明明就是一个数组算法,和二叉树有什么关系?接下来我就具体讲讲。
就这么说吧,所有递归的算法,你甭管它是干什么的,本质上都是在遍历一棵(递归)树,然后在节点(前中后序位置)上执行代码,你要写递归算法,本质上就是要告诉每个节点需要做什么。
你看归并排序的代码框架:
// 定义:排序 nums[lo..hi]
void sort(int[] nums, int lo, int hi) {
if (lo == hi) {
return;
}
int mid = (lo + hi) / 2;
// 利用定义,排序 nums[lo..mid]
sort(nums, lo, mid);
// 利用定义,排序 nums[mid+1..hi]
sort(nums, mid + 1, hi);
/****** 后序位置 ******/
// 此时两部分子数组已经被排好序
// 合并两个有序数组,使 nums[lo..hi] 有序
merge(nums, lo, mid, hi);
/*********************/
}
// 将有序数组 nums[lo..mid] 和有序数组 nums[mid+1..hi]
// 合并为有序数组 nums[lo..hi]
void merge(int[] nums, int lo, int mid, int hi);
// 注意:cpp 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
// 定义:排序 nums[lo..hi]
void sort(int[] nums, int lo, int hi) {
if (lo == hi) {
return;
}
int mid = (lo + hi) / 2;
// 利用定义,排序 nums[lo..mid]
sort(nums, lo, mid);
// 利用定义,排序 nums[mid+1..hi]
sort(nums, mid + 1, hi);
/****** 后序位置 ******/
// 此时两部分子数组已经被排好序
// 合并两个有序数组,使 nums[lo..hi] 有序
merge(nums, lo, mid, hi);
/*********************/
}
// 将有序数组 nums[lo..mid] 和有序数组 nums[mid+1..hi]
// 合并为有序数组 nums[lo..hi]
void merge(int[] nums, int lo, int mid, int hi);
# 注意:python 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
# 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
# 定义:排序 nums[lo..hi]
def sort(nums: List[int], lo: int, hi: int) -> None:
if lo == hi:
return
mid = (lo + hi) // 2
# 利用定义,排序 nums[lo..mid]
sort(nums, lo, mid)
# 利用定义,排序 nums[mid+1..hi]
sort(nums, mid + 1, hi)
/****** 后序位置 ******/
# 此时两部分子数组已经被排好序
# 合并两个有序数组,使 nums[lo..hi] 有序
merge(nums, lo, mid, hi)
/*********************/
# 将有序数组 nums[lo..mid] 和有序数组 nums[mid+1..hi]
# 合并为有序数组 nums[lo..hi]
def merge(nums: List[int], lo: int, mid: int, hi: int) -> None:
pass
// 注意:go 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
// 定义:排序 nums[lo..hi]
func sort(nums []int, lo, hi int) {
if lo == hi {
return
}
mid := (lo + hi) / 2
// 利用定义,排序 nums[lo..mid]
sort(nums, lo, mid)
// 利用定义,排序 nums[mid+1..hi]
sort(nums, mid+1, hi)
/****** 后序位置 ******/
// 此时两部分子数组已经被排好序
// 合并两个有序数组,使 nums[lo..hi] 有序
merge(nums, lo, mid, hi)
/*********************/
}
// 将有序数组 nums[lo..mid] 和有序数组 nums[mid+1..hi]
// 合并为有序数组 nums[lo..hi]
func merge(nums []int, lo, mid, hi int) {
// 实现略
}
// 注意:javascript 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
// 定义:排序 nums[lo..hi]
function sort(nums, lo, hi) {
if (lo == hi) {
return;
}
var mid = Math.floor((lo + hi) / 2);
// 利用定义,排序 nums[lo..mid]
sort(nums, lo, mid);
// 利用定义,排序 nums[mid+1..hi]
sort(nums, mid + 1, hi);
/****** 后序位置 ******/
// 此时两部分子数组已经被排好序
// 合并两个有序数组,使 nums[lo..hi] 有序
merge(nums, lo, mid, hi);
/*********************/
}
// 将有序数组 nums[lo..mid] 和有序数组 nums[mid+1..hi]
// 合并为有序数组 nums[lo..hi]
function merge(nums, lo, mid, hi);
看这个框架,也就明白那句经典的总结:归并排序就是先把左半边数组排好序,再把右半边数组排好序,然后把两半数组合并。
上述代码和二叉树的后序遍历很像:
/* 二叉树遍历框架 */
void traverse(TreeNode root) {
if (root == null) {
return;
}
traverse(root.left);
traverse(root.right);
/****** 后序位置 ******/
print(root.val);
/*********************/
}
// 注意:cpp 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
/* 二叉树遍历框架 */
void traverse(TreeNode* root) {
if (root == nullptr) {
return;
}
traverse(root->left);
traverse(root->right);
/****** 后序位置 ******/
cout << root->val;
/*********************/
}
# 注意:python 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
# 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
# 二叉树遍历框架
def traverse(root: TreeNode) -> None:
if root is None:
return
traverse(root.left)
traverse(root.right)
# 后序位置
print(root.val)
// 注意:go 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
// 二叉树遍历框架
func traverse(root *TreeNode) {
if root == nil {
return
}
traverse(root.Left)
traverse(root.Right)
// 后序位置
fmt.Println(root.Val)
}
// 注意:javascript 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
/* 二叉树遍历框架 */
var traverse = function(root) {
if (root === null) {
return;
}
traverse(root.left);
traverse(root.right);
/****** 后序位置 ******/
console.log(root.val);
/*********************/
};
再进一步,你联想一下求二叉树的最大深度的算法代码:
// 定义:输入根节点,返回这棵二叉树的最大深度
int maxDepth(TreeNode root) {
if (root == null) {
return 0;
}
// 利用定义,计算左右子树的最大深度
int leftMax = maxDepth(root.left);
int rightMax = maxDepth(root.right);
// 整棵树的最大深度等于左右子树的最大深度取最大值,
// 然后再加上根节点自己
int res = Math.max(leftMax, rightMax) + 1;
return res;
}
// 注意:cpp 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
// 定义:输入根节点,返回这棵二叉树的最大深度
int maxDepth(TreeNode* root) {
if (root == NULL) {
return 0;
}
// 利用定义,计算左右子树的最大深度
int leftMax = maxDepth(root->left);
int rightMax = maxDepth(root->right);
// 整棵树的最大深度等于左右子树的最大深度取最大值,
// 然后再加上根节点自己
int res = std::max(leftMax, rightMax) + 1;
return res;
}
# 注意:python 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
# 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
# 定义:输入根节点,返回这棵二叉树的最大深度
def maxDepth(root: TreeNode) -> int:
if not root:
return 0
# 利用定义,计算左右子树的最大深度
leftMax = maxDepth(root.left)
rightMax = maxDepth(root.right)
# 整棵树的最大深度等于左右子树的最大深度取最大值,
# 然后再加上根节点自己
res = max(leftMax, rightMax) + 1
return res
// 注意:go 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
// 定义:输入根节点,返回这棵二叉树的最大深度
func maxDepth(root *TreeNode) int {
if root == nil {
return 0
}
// 利用定义,计算左右子树的最大深度
leftMax := maxDepth(root.Left)
rightMax := maxDepth(root.Right)
// 整棵树的最大深度等于左右子树的最大深度取最大值,
// 然后再加上根节点自己
res := max(leftMax, rightMax) + 1
return res
}
func max(a, b int) int {
if a > b {
return a
}
return b
}
// 注意:javascript 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
// 定义:输入根节点,返回这棵二叉树的最大深度
var maxDepth = function(root) {
if (root === null) {
return 0;
}
// 利用定义,计算左右子树的最大深度
var leftMax = maxDepth(root.left);
var rightMax = maxDepth(root.right);
// 整棵树的最大深度等于左右子树的最大深度取最大值,
// 然后再加上根节点自己
var res = Math.max(leftMax, rightMax) + 1;
return res;
}
是不是更像了?
前文 手把手刷二叉树(纲领篇) 说二叉树问题可以分为两类思路,一类是遍历一遍二叉树的思路,另一类是分解问题的思路,根据上述类比,显然归并排序利用的是分解问题的思路( 分治算法)。
归并排序的过程可以在逻辑上抽象成一棵二叉树,树上的每个节点的值可以认为是 nums[lo..hi]
,叶子节点的值就是数组中的单个元素:
然后,在每个节点的后序位置(左右子节点已经被排好序)的时候执行 merge
函数,合并两个子节点上的子数组:
这个 merge
操作会在二叉树的每个节点上都执行一遍,执行顺序是二叉树后序遍历的顺序。
后序遍历二叉树大家应该已经烂熟于心了,就是下图这个遍历顺序:
结合上述基本分析,我们把 nums[lo..hi]
理解成二叉树的节点,sort
函数理解成二叉树的遍历函数,整个归并排序的执行过程就是以下 GIF 描述的这样:
这样,归并排序的核心思路就分析完了,接下来只要把思路翻译成代码就行。
只要拥有了正确的思维方式,理解算法思路是不困难的,但把思路实现成代码,也很考验一个人的编程能力。
毕竟算法的时间复杂度只是一个理论上的衡量标准,而算法的实际运行效率要考虑的因素更多,比如应该避免内存的频繁分配释放,代码逻辑应尽可能简洁等等。
经过对比,《算法 4》中给出的归并排序代码兼具了简洁和高效的特点,所以我们可以参考书中给出的代码作为归并算法模板:
class Merge {
// 用于辅助合并有序数组
private static int[] temp;
public static void sort(int[] nums) {
// 先给辅助数组开辟内存空间
temp = new int[nums.length];
// 排序整个数组(原地修改)
sort(nums, 0, nums.length - 1);
}
// 定义:将子数组 nums[lo..hi] 进行排序
private static void sort(int[] nums, int lo, int hi) {
if (lo == hi) {
// 单个元素不用排序
return;
}
// 这样写是为了防止溢出,效果等同于 (hi + lo) / 2
int mid = lo + (hi - lo) / 2;
// 先对左半部分数组 nums[lo..mid] 排序
sort(nums, lo, mid);
// 再对右半部分数组 nums[mid+1..hi] 排序
sort(nums, mid + 1, hi);
// 将两部分有序数组合并成一个有序数组
merge(nums, lo, mid, hi);
}
// 将 nums[lo..mid] 和 nums[mid+1..hi] 这两个有序数组合并成一个有序数组
private static void merge(int[] nums, int lo, int mid, int hi) {
// 先把 nums[lo..hi] 复制到辅助数组中
// 以便合并后的结果能够直接存入 nums
for (int i = lo; i <= hi; i++) {
temp[i] = nums[i];
}
// 数组双指针技巧,合并两个有序数组
int i = lo, j = mid + 1;
for (int p = lo; p <= hi; p++) {
if (i == mid + 1) {
// 左半边数组已全部被合并
nums[p] = temp[j++];
} else if (j == hi + 1) {
// 右半边数组已全部被合并
nums[p] = temp[i++];
} else if (temp[i] > temp[j]) {
nums[p] = temp[j++];
} else {
nums[p] = temp[i++];
}
}
}
}
有了之前的铺垫,这里只需要着重讲一下这个 merge
函数。
sort
函数对 nums[lo..mid]
和 nums[mid+1..hi]
递归排序完成之后,我们没有办法原地把它俩合并,所以需要 copy 到 temp
数组里面,然后通过类似于前文
单链表的六大技巧 中合并有序链表的双指针技巧将 nums[lo..hi]
合并成一个有序数组:
注意我们不是在 merge
函数执行的时候 new 辅助数组,而是提前把 temp
辅助数组 new 出来了,这样就避免了在递归中频繁分配和释放内存可能产生的性能问题。
再说一下归并排序的时间复杂度,虽然大伙儿应该都知道是 O(NlogN)
,但不见得所有人都知道这个复杂度怎么算出来的。
前文
动态规划详解 说过递归算法的复杂度计算,就是子问题个数 x 解决一个子问题的复杂度。对于归并排序来说,时间复杂度显然集中在 merge
函数遍历 nums[lo..hi]
的过程,但每次 merge
输入的 lo
和 hi
都不同,所以不容易直观地看出时间复杂度。
merge
函数到底执行了多少次?每次执行的时间复杂度是多少?总的时间复杂度是多少?这就要结合之前画的这幅图来看:
执行的次数是二叉树节点的个数,每次执行的复杂度就是每个节点代表的子数组的长度,所以总的时间复杂度就是整棵树中「数组元素」的个数。
所以从整体上看,这个二叉树的高度是 logN
,其中每一层的元素个数就是原数组的长度 N
,所以总的时间复杂度就是 O(NlogN)
。
力扣第 912 题「 排序数组」就是让你对数组进行排序,我们可以直接套用归并排序代码模板:
class Solution {
public int[] sortArray(int[] nums) {
// 归并排序对数组进行原地排序
Merge.sort(nums);
return nums;
}
}
class Merge {
// 见上文
}
// 注意:cpp 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
class Solution {
public:
vector<int> sortArray(vector<int>& nums) {
// 归并排序对数组进行原地排序
Merge::sort(nums);
return nums;
}
};
class Merge {
public:
static void sort(vector<int>& nums) {
// 见上文
}
};
# 注意:python 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
# 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
class Solution:
def sortArray(self, nums: List[int]) -> List[int]:
# 归并排序对数组进行原地排序
self.mergeSort(nums)
return nums
def mergeSort(self, nums: List[int]):
# 见上文
pass
// 注意:go 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
func sortArray(nums []int) []int {
// 归并排序对数组进行原地排序
MergeSort(nums)
return nums
}
func MergeSort(nums []int) {
// 见上文
}
// 注意:javascript 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
var sortArray = function (nums) {
Merge.sort(nums);
return nums;
};
class Merge {
static sort(nums) {
// 实现归并排序
}
}
除了最基本的排序问题,归并排序还可以用来解决力扣第 315 题「 计算右侧小于当前元素的个数」:
我用比较数学的语言来描述一下(方便和后续类似题目进行对比),题目让你求出一个 count
数组,使得:
count[i] = COUNT(j) where j > i and nums[j] < nums[i]
拍脑袋的暴力解法就不说了,嵌套 for 循环,平方级别的复杂度。
这题和归并排序什么关系呢,主要在 merge
函数,我们在使用 merge
函数合并两个有序数组的时候,其实是可以知道一个元素 nums[i]
后边有多少个元素比 nums[i]
小的。
具体来说,比如这个场景:
这时候我们应该把 temp[i]
放到 nums[p]
上,因为 temp[i] < temp[j]
。
但就在这个场景下,我们还可以知道一个信息:5 后面比 5 小的元素个数就是 左闭右开区间 [mid + 1, j)
中的元素个数,即 2 和 4 这两个元素:
换句话说,在对 nums[lo..hi]
合并的过程中,每当执行 nums[p] = temp[i]
时,就可以确定 temp[i]
这个元素后面比它小的元素个数为 j - mid - 1
。
当然,nums[lo..hi]
本身也只是一个子数组,这个子数组之后还会被执行 merge
,其中元素的位置还是会改变。但这是其他递归节点需要考虑的问题,我们只要在 merge
函数中做一些手脚,叠加每次 merge
时记录的结果即可。
发现了这个规律后,我们只要在 merge
中添加两行代码即可解决这个问题,看解法代码:
class Solution {
private class Pair {
int val, id;
Pair(int val, int id) {
// 记录数组的元素值
this.val = val;
// 记录元素在数组中的原始索引
this.id = id;
}
}
// 归并排序所用的辅助数组
private Pair[] temp;
// 记录每个元素后面比自己小的元素个数
private int[] count;
// 主函数
public List<Integer> countSmaller(int[] nums) {
int n = nums.length;
count = new int[n];
temp = new Pair[n];
Pair[] arr = new Pair[n];
// 记录元素原始的索引位置,以便在 count 数组中更新结果
for (int i = 0; i < n; i++)
arr[i] = new Pair(nums[i], i);
// 执行归并排序,本题结果被记录在 count 数组中
sort(arr, 0, n - 1);
List<Integer> res = new LinkedList<>();
for (int c : count) res.add(c);
return res;
}
// 归并排序
private void sort(Pair[] arr, int lo, int hi) {
if (lo == hi) return;
int mid = lo + (hi - lo) / 2;
sort(arr, lo, mid);
sort(arr, mid + 1, hi);
merge(arr, lo, mid, hi);
}
// 合并两个有序数组
private void merge(Pair[] arr, int lo, int mid, int hi) {
for (int i = lo; i <= hi; i++) {
temp[i] = arr[i];
}
int i = lo, j = mid + 1;
for (int p = lo; p <= hi; p++) {
if (i == mid + 1) {
arr[p] = temp[j++];
} else if (j == hi + 1) {
arr[p] = temp[i++];
// 更新 count 数组
count[arr[p].id] += j - mid - 1;
} else if (temp[i].val > temp[j].val) {
arr[p] = temp[j++];
} else {
arr[p] = temp[i++];
// 更新 count 数组
count[arr[p].id] += j - mid - 1;
}
}
}
}
// 注意:cpp 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
class Solution {
public:
struct Pair {
int val, id;
Pair(int val, int id) {
// 记录数组的元素值
this->val = val;
// 记录元素在数组中的原始索引
this->id = id;
}
};
// 归并排序所用的辅助数组
Pair* temp;
// 记录每个元素后面比自己小的元素个数
int* count;
// 主函数
vector<int> countSmaller(vector<int>& nums) {
int n = nums.size();
count = new int[n]();
temp = new Pair[n]();
Pair* arr = new Pair[n];
// 记录元素原始的索引位置,以便在 count 数组中更新结果
for (int i = 0; i < n; i++)
arr[i] = Pair(nums[i], i);
// 执行归并排序,本题结果被记录在 count 数组中
sort(arr, 0, n - 1);
vector<int> res;
for (int i = 0; i < n; i++)
res.push_back(count[i]);
delete[] count;
delete[] temp;
return res;
}
// 归并排序
void sort(Pair* arr, int lo, int hi) {
if (lo == hi) return;
int mid = lo + (hi - lo) / 2;
sort(arr, lo, mid);
sort(arr, mid + 1, hi);
merge(arr, lo, mid, hi);
}
// 合并两个有序数组
void merge(Pair* arr, int lo, int mid, int hi) {
for (int i = lo; i <= hi; i++) {
temp[i] = arr[i];
}
int i = lo, j = mid + 1;
for (int p = lo; p <= hi; p++) {
if (i == mid + 1) {
arr[p] = temp[j++];
} else if (j == hi + 1) {
arr[p] = temp[i++];
// 更新 count 数组
count[arr[p].id] += j - mid - 1;
} else if (temp[i].val > temp[j].val) {
arr[p] = temp[j++];
} else {
arr[p] = temp[i++];
// 更新 count 数组
count[arr[p].id] += j - mid - 1;
}
}
}
};
# 注意:python 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
# 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
from typing import List
class Solution:
def countSmaller(self, nums: List[int]) -> List[int]:
class Pair:
def __init__(self, val: int, id: int):
# 记录数组的元素值
self.val = val
# 记录元素在数组中的原始索引
self.id = id
# 归并排序所用的辅助数组
temp = [Pair(0,0) for _ in range(len(nums))]
# 记录每个元素后面比自己小的元素个数
count = [0 for _ in range(len(nums))]
# 归并排序
def sort(arr:List[Pair], lo:int, hi:int) -> None:
if lo == hi:
return
mid = lo + (hi - lo) // 2
sort(arr, lo, mid)
sort(arr, mid + 1, hi)
merge(arr, lo, mid, hi)
# 合并两个有序数组
def merge(arr:List[Pair], lo:int, mid:int, hi:int) -> None:
for i in range(lo, hi + 1):
temp[i] = arr[i]
i, j = lo, mid + 1
for p in range(lo, hi + 1):
if i == mid + 1:
arr[p] = temp[j]
j += 1
elif j == hi + 1:
arr[p] = temp[i]
i += 1
# 更新 count 数组
count[arr[p].id] += j - mid - 1
elif temp[i].val > temp[j].val:
arr[p] = temp[j]
j += 1
else:
arr[p] = temp[i]
i += 1
# 更新 count 数组
count[arr[p].id] += j - mid - 1
n = len(nums)
arr = [Pair(nums[i], i) for i in range(n)]
# 执行归并排序,本题结果被记录在 count 数组中
sort(arr, 0, n - 1)
res = [c for c in count]
return res
// 注意:go 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
class Solution {
private class Pair {
int val, id;
Pair(int val, int id) {
// 记录数组的元素值
this.val = val;
// 记录元素在数组中的原始索引
this.id = id;
}
}
// 归并排序所用的辅助数组
private Pair[] temp;
// 记录每个元素后面比自己小的元素个数
private int[] count;
// 主函数
public List<Integer> countSmaller(int[] nums) {
int n = nums.length;
count = new int[n];
temp = new Pair[n];
Pair[] arr = new Pair[n];
// 记录元素原始的索引位置,以便在 count 数组中更新结果
for (int i = 0; i < n; i++)
arr[i] = new Pair(nums[i], i);
// 执行归并排序,本题结果被记录在 count 数组中
sort(arr, 0, n - 1);
List<Integer> res = new LinkedList<>();
for (int c : count) res.add(c);
return res;
}
// 归并排序
private void sort(Pair[] arr, int lo, int hi) {
if (lo == hi) return;
int mid = lo + (hi - lo) / 2;
sort(arr, lo, mid);
sort(arr, mid + 1, hi);
merge(arr, lo, mid, hi);
}
// 合并两个有序数组
private void merge(Pair[] arr, int lo, int mid, int hi) {
for (int i = lo; i <= hi; i++) {
temp[i] = arr[i];
}
int i = lo, j = mid + 1;
for (int p = lo; p <= hi; p++) {
if (i == mid + 1) {
arr[p] = temp[j++];
} else if (j == hi + 1) {
arr[p] = temp[i++];
// 更新 count 数组
count[arr[p].id] += j - mid - 1;
} else if (temp[i].val > temp[j].val) {
arr[p] = temp[j++];
} else {
arr[p] = temp[i++];
// 更新 count 数组
count[arr[p].id] += j - mid - 1;
}
}
}
}
// 注意:javascript 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
var countSmaller = function(nums) {
class Pair {
constructor(val, id) {
// 记录数组的元素值
this.val = val;
// 记录元素在数组中的原始索引
this.id = id;
}
}
// 归并排序所用的辅助数组
var temp;
// 记录每个元素后面比自己小的元素个数
var count;
var sort = function(arr, lo, hi) {
if (lo == hi) return;
var mid = lo + Math.floor((hi - lo) / 2);
sort(arr, lo, mid);
sort(arr, mid + 1, hi);
merge(arr, lo, mid, hi);
};
// 合并两个有序数组
var merge = function(arr, lo, mid, hi) {
for (var i = lo; i <= hi; i++) {
temp[i] = arr[i];
}
var i = lo, j = mid + 1;
for (var p = lo; p <= hi; p++) {
if (i == mid + 1) {
arr[p] = temp[j++];
} else if (j == hi + 1) {
arr[p] = temp[i++];
// 更新 count 数组
count[arr[p].id] += j - mid - 1;
} else if (temp[i].val > temp[j].val) {
arr[p] = temp[j++];
} else {
arr[p] = temp[i++];
// 更新 count 数组
count[arr[p].id] += j - mid - 1;
}
}
};
var n = nums.length;
count = new Array(n).fill(0);
temp = new Array(n);
var arr = new Array(n);
// 记录元素原始的索引位置,以便在 count 数组中更新结果
for (var i = 0; i < n; i++)
arr[i] = new Pair(nums[i], i);
// 执行归并排序,本题结果被记录在 count 数组中
sort(arr, 0, n - 1);
var res = [];
for (var c of count) res.push(c);
return res;
};
因为在排序过程中,每个元素的索引位置会不断改变,所以我们用一个 Pair
类封装每个元素及其在原始数组 nums
中的索引,以便 count
数组记录每个元素之后小于它的元素个数。
接下来我们再看几道原理类似的题目,都是通过给归并排序的 merge
函数加一些私货完成目标。
看一下力扣第 493 题「 翻转对」:
我把这道题换个表述方式,你注意和上一道题目对比:
请你先求出一个 count
数组,其中:
count[i] = COUNT(j) where j > i and nums[i] > 2*nums[j]
然后请你求出这个 count
数组中所有元素的和。
你看,这样说其实和题目是一个意思,而且和上一道题非常类似,只不过上一题求的是 nums[i] > nums[j]
,这里求的是 nums[i] > 2*nums[j]
罢了。
所以解题的思路当然还是要在 merge
函数中做点手脚,当 nums[lo..mid]
和 nums[mid+1..hi]
两个子数组完成排序后,对于 nums[lo..mid]
中的每个元素 nums[i]
,去 nums[mid+1..hi]
中寻找符合条件的 nums[j]
就行了。
看一下我们对 merge
函数的改造:
// 记录「翻转对」的个数
int count = 0;
// 将 nums[lo..mid] 和 nums[mid+1..hi] 这两个有序数组合并成一个有序数组
private void merge(int[] nums, int lo, int mid, int hi) {
for (int i = lo; i <= hi; i++) {
temp[i] = nums[i];
}
// 在合并有序数组之前,加点私货
for (int i = lo; i <= mid; i++) {
// 对于左半边的每个 nums[i],都去右半边寻找符合条件的元素
for (int j = mid + 1; j <= hi; j++) {
// nums 中的元素可能较大,乘 2 可能溢出,所以转化成 long
if ((long)nums[i] > (long)nums[j] * 2) {
count++;
}
}
}
// 数组双指针技巧,合并两个有序数组
int i = lo, j = mid + 1;
for (int p = lo; p <= hi; p++) {
if (i == mid + 1) {
nums[p] = temp[j++];
} else if (j == hi + 1) {
nums[p] = temp[i++];
} else if (temp[i] > temp[j]) {
nums[p] = temp[j++];
} else {
nums[p] = temp[i++];
}
}
}
// 注意:cpp 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
// 记录「翻转对」的个数
int count = 0;
// 将 nums[lo..mid] 和 nums[mid+1..hi] 这两个有序数组合并成一个有序数组
void merge(vector<int>& nums, int lo, int mid, int hi) {
vector<int> temp(hi - lo + 1);
for (int i = lo; i <= hi; i++) {
temp[i - lo] = nums[i];
}
// 在合并有序数组之前,加点私货
for (int i = lo; i <= mid; i++) {
// 对于左半边的每个 nums[i],都去右半边寻找符合条件的元素
for (int j = mid + 1; j <= hi; j++) {
// nums 中的元素可能较大,乘 2 可能溢出,所以转化成 long long
if ((long long) nums[i] > (long long) nums[j] * 2) {
count++;
}
}
}
// 数组双指针技巧,合并两个有序数组
int i = lo, j = mid + 1, p = 0;
while (i <= mid && j <= hi) {
if (temp[i - lo] > temp[j - lo]) {
nums[lo + p] = temp[j - lo];
j++;
} else {
nums[lo + p] = temp[i - lo];
i++;
}
p++;
}
while (i <= mid) {
nums[lo + p] = temp[i - lo];
p++;
i++;
}
while (j <= hi) {
nums[lo + p] = temp[j - lo];
p++;
j++;
}
}
# 注意:python 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
# 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
# 记录「翻转对」的个数
count = 0
# 将 nums[lo..mid] 和 nums[mid+1..hi] 这两个有序数组合并成一个有序数组
def merge(nums: List[int], lo: int, mid: int, hi: int) -> None:
global count
for i in range(lo, hi + 1):
temp[i] = nums[i]
# 在合并有序数组之前,加点私货
for i in range(lo, mid + 1):
# 对于左半边的每个 nums[i],都去右半边寻找符合条件的元素
for j in range(mid + 1, hi + 1):
# nums 中的元素可能较大,乘 2 可能溢出,所以转化成 long
if nums[i] > nums[j] * 2:
count += 1
# 数组双指针技巧,合并两个有序数组
i, j = lo, mid + 1
for p in range(lo, hi + 1):
if i == mid + 1:
nums[p] = temp[j]
j += 1
elif j == hi + 1:
nums[p] = temp[i]
i += 1
elif temp[i] > temp[j]:
nums[p] = temp[j]
j += 1
else:
nums[p] = temp[i]
i += 1
// 注意:go 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
// 记录「翻转对」的个数
var count int
// 将 nums[lo..mid] 和 nums[mid+1..hi] 这两个有序数组合并成一个有序数组
func merge(nums []int, lo int, mid int, hi int) {
for i := lo; i <= hi; i++ {
temp[i] = nums[i]
}
// 在合并有序数组之前,加点私货
for i := lo; i <= mid; i++ {
// 对于左半边的每个 nums[i],都去右半边寻找符合条件的元素
for j := mid + 1; j <= hi; j++ {
// nums 中的元素可能较大,乘 2 可能溢出,所以转化成 long
if int64(nums[i]) > int64(nums[j])*2 {
count++
}
}
}
// 数组双指针技巧,合并两个有序数组
i, j := lo, mid+1
for p := lo; p <= hi; p++ {
if i == mid+1 {
nums[p] = temp[j]
j++
} else if j == hi+1 {
nums[p] = temp[i]
i++
} else if temp[i] > temp[j] {
nums[p] = temp[j]
j++
} else {
nums[p] = temp[i]
i++
}
}
}
// 注意:javascript 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
// 记录「翻转对」的个数
var count = 0;
// 将 nums[lo..mid] 和 nums[mid+1..hi] 这两个有序数组合并成一个有序数组
function merge(nums, lo, mid, hi) {
for (var i = lo; i <= hi; i++) {
temp[i] = nums[i];
}
// 在合并有序数组之前,加点私货
for (var i = lo; i <= mid; i++) {
// 对于左半边的每个 nums[i],都去右半边寻找符合条件的元素
for (var j = mid + 1; j <= hi; j++) {
// nums 中的元素可能较大,乘 2 可能溢出,所以转化成 long
if (nums[i] > nums[j] * 2) {
count++;
}
}
}
// 数组双指针技巧,合并两个有序数组
var i = lo, j = mid + 1;
for (var p = lo; p <= hi; p++) {
if (i == mid + 1) {
nums[p] = temp[j++];
} else if (j == hi + 1) {
nums[p] = temp[i++];
} else if (temp[i] > temp[j]) {
nums[p] = temp[j++];
} else {
nums[p] = temp[i++];
}
}
}
不过呢,这段代码提交会超时,毕竟额外添加了一个嵌套 for 循环。怎么进行优化呢,注意子数组 nums[lo..mid]
是排好序的,也就是 nums[i] <= nums[i+1]
。
所以,对于 nums[i], lo <= i <= mid
,我们在找到的符合 nums[i] > 2*nums[j]
的 nums[j], mid+1 <= j <= hi
,也必然也符合 nums[i+1] > 2*nums[j]
。
换句话说,我们不用每次都傻乎乎地去遍历整个 nums[mid+1..hi]
,只要维护一个开区间边界 end
,维护 nums[mid+1..end-1]
是符合条件的元素即可。
看最终的解法代码:
class Solution {
public int reversePairs(int[] nums) {
// 执行归并排序
sort(nums);
return count;
}
private int[] temp;
public void sort(int[] nums) {
temp = new int[nums.length];
sort(nums, 0, nums.length - 1);
}
// 归并排序
private void sort(int[] nums, int lo, int hi) {
if (lo == hi) {
return;
}
int mid = lo + (hi - lo) / 2;
sort(nums, lo, mid);
sort(nums, mid + 1, hi);
merge(nums, lo, mid, hi);
}
// 记录「翻转对」的个数
private int count = 0;
private void merge(int[] nums, int lo, int mid, int hi) {
for (int i = lo; i <= hi; i++) {
temp[i] = nums[i];
}
// 进行效率优化,维护左闭右开区间 [mid+1, end) 中的元素乘 2 小于 nums[i]
// 为什么 end 是开区间?因为这样的话可以保证初始区间 [mid+1, mid+1) 是一个空区间
int end = mid + 1;
for (int i = lo; i <= mid; i++) {
// nums 中的元素可能较大,乘 2 可能溢出,所以转化成 long
while (end <= hi && (long)nums[i] > (long)nums[end] * 2) {
end++;
}
count += end - (mid + 1);
}
// 数组双指针技巧,合并两个有序数组
int i = lo, j = mid + 1;
for (int p = lo; p <= hi; p++) {
if (i == mid + 1) {
nums[p] = temp[j++];
} else if (j == hi + 1) {
nums[p] = temp[i++];
} else if (temp[i] > temp[j]) {
nums[p] = temp[j++];
} else {
nums[p] = temp[i++];
}
}
}
}
// 注意:cpp 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
class Solution {
public:
void sort(vector<int>& nums, int lo, int hi);
int reversePairs(vector<int>& nums) {
// 执行归并排序
sort(nums, 0, nums.size()-1);
return count;
}
private:
vector<int> temp;
// 归并排序
void sort(vector<int>& nums, int lo, int hi) {
if (lo == hi) {
return;
}
int mid = lo + (hi - lo) / 2;
sort(nums, lo, mid);
sort(nums, mid + 1, hi);
merge(nums, lo, mid, hi);
}
// 记录「翻转对」的个数
int count = 0;
void merge(vector<int>& nums, int lo, int mid, int hi) {
for (int i = lo; i <= hi; i++) {
temp[i] = nums[i];
}
// 进行效率优化,维护左闭右开区间 [mid+1, end) 中的元素乘 2 小于 nums[i]
// 为什么 end 是开区间?因为这样的话可以保证初始区间 [mid+1, mid+1) 是一个空区间
int end = mid + 1;
for (int i = lo; i <= mid; i++) {
// nums 中的元素可能较大,乘 2 可能溢出,所以转化成 long
while (end <= hi && (long)nums[i] > (long)nums[end] * 2) {
end++;
}
count += (end - (mid + 1));
}
// 数组双指针技巧,合并两个有序数组
int i = lo, j = mid + 1;
for (int p = lo; p <= hi; p++) {
if (i == mid + 1) {
nums[p] = temp[j++];
} else if (j == hi + 1) {
nums[p] = temp[i++];
} else if (temp[i] > temp[j]) {
nums[p] = temp[j++];
} else {
nums[p] = temp[i++];
}
}
}
};
# 注意:python 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
# 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
class Solution:
def reversePairs(self, nums: List[int]) -> int:
self.count = 0
self.__sort(nums)
return self.count
def __sort(self, nums):
self.temp = [0] * len(nums)
self.__sort_helper(nums, 0, len(nums)-1)
# 归并排序
def __sort_helper(self, nums, lo, hi):
if lo >= hi:
return
mid = (lo + hi) // 2
self.__sort_helper(nums, lo, mid)
self.__sort_helper(nums, mid+1, hi)
self.__merge(nums, lo, mid, hi)
# 数组双指针技巧,合并两个有序数组
def __merge(self, nums, lo, mid, hi):
for i in range(lo, hi+1):
self.temp[i] = nums[i]
# 进行效率优化,维护左闭右开区间 [mid+1, end) 中的元素乘 2 小于 nums[i]
# 为什么 end 是开区间?因为这样的话可以保证初始区间 [mid+1, mid+1) 是一个空区间
end = mid + 1
for i in range(lo, mid+1):
# nums 中的元素可能较大,乘 2 可能溢出,所以转化成 long
while end <= hi and (long)nums[i] > (long)nums[end] * 2:
end += 1
self.count += end - (mid + 1)
i, j = lo, mid+1
for k in range(lo, hi+1):
if i > mid:
nums[k] = self.temp[j]
j += 1
elif j > hi:
nums[k] = self.temp[i]
i += 1
elif self.temp[i] > self.temp[j]:
nums[k] = self.temp[j]
j += 1
else:
nums[k] = self.temp[i]
i += 1
// 注意:go 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
func reversePairs(nums []int) int {
temp := make([]int, len(nums))
count := 0
// 执行归并排序
var sort func([]int, int, int)
sort = func(nums []int, lo int, hi int) {
if lo == hi {
return
}
mid := lo + (hi - lo) / 2
sort(nums, lo, mid)
sort(nums, mid + 1, hi)
merge(nums, temp, lo, mid, hi, &count)
}
sort(nums, 0, len(nums) - 1)
return count
}
func merge(nums []int, temp []int, lo int, mid int, hi int, count *int) {
for i := lo; i <= hi; i++ {
temp[i] = nums[i]
}
end := mid + 1
for i := lo; i <= mid; i++ {
for end <= hi && int64(nums[i]) > int64(nums[end]) * 2 {
end++
}
*count += end - (mid + 1)
}
i, j := lo, mid + 1
for p := lo; p <= hi; p++ {
if i == mid + 1 {
nums[p] = temp[j]
j++
} else if j == hi + 1 {
nums[p] = temp[i]
i++
} else if temp[i] > temp[j] {
nums[p] = temp[j]
j++
} else {
nums[p] = temp[i]
i++
}
}
}
// 注意:javascript 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
// 您应该保留中文注释,不应该修改它们
function reversePairs(nums) {
let count = 0;
// 递归的目的:将 nums 左右区间排好序,使得 nums 左区间中的数都在右区间中的数的前面
function sortHelper(nums, lo, hi) {
if (lo == hi) {
return;
}
let mid = lo + Math.floor((hi - lo) / 2);
sortHelper(nums, lo, mid);
sortHelper(nums, mid + 1, hi);
merge(nums, lo, mid, hi);
}
// 合并左右区间
function merge(nums, lo, mid, hi) {
// 用 temp 保存一下原始区间,方便修改
let temp = new Array(nums.length).fill(0);
for (let i = lo; i <= hi; i++) {
temp[i] = nums[i];
}
let start = lo, end = mid + 1;
for (let i = lo; i <= hi; i++) {
if (start == mid + 1) {
nums[i] = temp[end++];
} else if (end == hi + 1) {
nums[i] = temp[start++];
} else if (temp[start] <= temp[end]) {
nums[i] = temp[start++];
} else {
nums[i] = temp[end++];
count += mid - start + 1;
}
}
}
// 通过递归,不断地将左右区间归并排序,然后每个小区间都是有序的
sortHelper(nums, 0, nums.length - 1);
return count;
}
如果你能够理解这道题目,我们最后来看一道难度更大的题目,力扣第 327 题「 区间和的个数」:
简单说,题目让你计算元素和落在 [lower, upper]
中的所有子数组的个数。
拍脑袋的暴力解法我就不说了,依然是嵌套 for 循环,这里还是说利用归并排序实现的高效算法。
首先,解决这道题需要快速计算子数组的和,所以你需要阅读前文
前缀和数组技巧,创建一个前缀和数组 preSum
来辅助我们迅速计算区间和。
我继续用比较数学的语言来表述下这道题,题目让你通过 preSum
数组求一个 count
数组,使得:
count[i] = COUNT(j) where lower <= preSum[j] - preSum[i] <= upper
然后请你求出这个 count
数组中所有元素的和。
你看,这是不是和题目描述一样?preSum
中的两个元素之差其实就是区间和。
有了之前两道题的铺垫,我直接给出这道题的解法代码吧,思路见注释:
class Solution {
private int lower, upper;
public int countRangeSum(int[] nums, int lower, int upper) {
this.lower = lower;
this.upper = upper;
// 构建前缀和数组,注意 int 可能溢出,用 long 存储
long[] preSum = new long[nums.length + 1];
for (int i = 0; i < nums.length; i++) {
preSum[i + 1] = (long)nums[i] + preSum[i];
}
// 对前缀和数组进行归并排序
sort(preSum);
return count;
}
private long[] temp;
public void sort(long[] nums) {
temp = new long[nums.length];
sort(nums, 0, nums.length - 1);
}
private void sort(long[] nums, int lo, int hi) {
if (lo == hi) {
return;
}
int mid = lo + (hi - lo) / 2;
sort(nums, lo, mid);
sort(nums, mid + 1, hi);
merge(nums, lo, mid, hi);
}
private int count = 0;
private void merge(long[] nums, int lo, int mid, int hi) {
for (int i = lo; i <= hi; i++) {
temp[i] = nums[i];
}
// 在合并有序数组之前加点私货(这段代码会超时)
// for (int i = lo; i <= mid; i++) {
// for (int j = mid + 1; j <= hi; k++) {
// // 寻找符合条件的 nums[j]
// long delta = nums[j] - nums[i];
// if (delta <= upper && delta >= lower) {
// count++;
// }
// }
// }
// 进行效率优化
// 维护左闭右开区间 [start, end) 中的元素和 nums[i] 的差在 [lower, upper] 中
int start = mid + 1, end = mid + 1;
for (int i = lo; i <= mid; i++) {
// 如果 nums[i] 对应的区间是 [start, end),
// 那么 nums[i+1] 对应的区间一定会整体右移,类似滑动窗口
while (start <= hi && nums[start] - nums[i] < lower) {
start++;
}
while (end <= hi && nums[end] - nums[i] <= upper) {
end++;
}
count += end - start;
}
// 数组双指针技巧,合并两个有序数组
int i = lo, j = mid + 1;
for (int p = lo; p <= hi; p++) {
if (i == mid + 1) {
nums[p] = temp[j++];
} else if (j == hi + 1) {
nums[p] = temp[i++];
} else if (temp[i] > temp[j]) {
nums[p] = temp[j++];
} else {
nums[p] = temp[i++];
}
}
}
}
// 注意:cpp 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
class Solution {
private:
int lower, upper;
public:
int countRangeSum(vector<int>& nums, int lower, int upper) {
this->lower = lower;
this->upper = upper;
int n = nums.size();
// 构建前缀和数组,注意 int 可能溢出,用 long long 存储
vector<long long> preSum(n + 1);
for (int i = 0; i < n; i++) {
preSum[i + 1] = nums[i] + preSum[i];
}
// 对前缀和数组进行归并排序
sort(preSum.begin(), preSum.end());
int count = 0;
vector<long long> temp(n + 1);
mergeSort(preSum, temp, 0, n, count);
return count;
}
void mergeSort(vector<long long>& nums, vector<long long>& temp, int left, int right, int& count) {
if (left >= right) {
return;
}
int mid = left + (right - left) / 2;
mergeSort(nums, temp, left, mid, count);
mergeSort(nums, temp, mid + 1, right, count);
// 维护左闭右开区间 [start, end) 中的元素和 nums[i] 的差在 [lower, upper] 中
int start = mid + 1, end = mid + 1;
for (int i = left; i <= mid; i++) {
while (start <= right && nums[start] - nums[i] < lower) {
start++;
}
while (end <= right && nums[end] - nums[i] <= upper) {
end++;
}
count += end - start;
}
// 数组双指针技巧,合并两个有序数组
int i = left, j = mid + 1, p = left;
while (i <= mid && j <= right) {
if (nums[i] <= nums[j]) {
temp[p++] = nums[i++];
} else {
temp[p++] = nums[j++];
}
}
while (i <= mid) {
temp[p++] = nums[i++];
}
while (j <= right) {
temp[p++] = nums[j++];
}
copy(temp.begin() + left, temp.begin() + right + 1, nums.begin() + left);
}
};
# 注意:python 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
# 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
class Solution:
def countRangeSum(self, nums: List[int], lower: int, upper: int) -> int:
self.lower = lower
self.upper = upper
# 构建前缀和数组,注意 int 可能溢出,用 long 存储
preSum = [0] * (len(nums) + 1)
for i in range(len(nums)):
preSum[i + 1] = nums[i] + preSum[i]
# 对前缀和数组进行归并排序
self.sort(preSum)
return self.count
def sort(self, nums: List[int]):
self.temp = [0] * len(nums)
self.__sort(nums, 0, len(nums) - 1)
def __sort(self, nums: List[int], lo: int, hi: int):
if lo == hi:
return
mid = lo + (hi - lo) // 2
self.__sort(nums, lo, mid)
self.__sort(nums, mid + 1, hi)
self.__merge(nums, lo, mid, hi)
def __merge(self, nums: List[int], lo: int, mid: int, hi: int):
for i in range(lo, hi + 1):
self.temp[i] = nums[i]
# 在合并有序数组之前加点私货(这段代码会超时)
# for (int i = lo; i <= mid; i++) {
# for (int j = mid + 1; j <= hi; k++) {
# // 寻找符合条件的 nums[j]
# long delta = nums[j] - nums[i];
# if (delta <= upper && delta >= lower) {
# count++;
# }
# }
# }
# 进行效率优化
# 维护左闭右开区间 [start, end) 中的元素和 nums[i] 的差在 [lower, upper] 中
start = end = mid + 1
for i in range(lo, mid + 1):
# 如果 nums[i] 对应的区间是 [start, end),
# 那么 nums[i+1] 对应的区间一定会整体右移,类似滑动窗口
while start <= hi and nums[start] - nums[i] < self.lower:
start += 1
while end <= hi and nums[end] - nums[i] <= self.upper:
end += 1
self.count += end - start
# 数组双指针技巧,合并两个有序数组
i, j = lo, mid + 1
for p in range(lo, hi + 1):
if i == mid + 1:
nums[p] = self.temp[j]
j += 1
elif j == hi + 1:
nums[p] = self.temp[i]
i += 1
elif self.temp[i] > self.temp[j]:
nums[p] = self.temp[j]
j += 1
else:
nums[p] = self.temp[i]
i += 1
// 注意:go 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
func countRangeSum(nums []int, lower int, upper int) int {
count := 0
sortHelper := func(nums []int64, lo int, hi int) {
if lo == hi {
return
}
mid := lo + (hi-lo)/2
sortHelper(nums, lo, mid)
sortHelper(nums, mid+1, hi)
merge(nums, lo, mid, hi)
}
merge := func(nums []int64, lo int, mid int, hi int) {
temp := make([]int64, len(nums))
for i := lo; i <= hi; i++ {
temp[i] = nums[i]
}
// 进行效率优化
// 维护左闭右开区间 [start, end) 中的元素和 nums[i] 的差在 [lower, upper] 中
start := mid + 1
end := mid + 1
for i := lo; i <= mid; i++ {
for start <= hi && nums[start]-nums[i] < int64(lower) {
start++
}
for end <= hi && nums[end]-nums[i] <= int64(upper) {
end++
}
count += end - start
}
// 数组双指针技巧,合并两个有序数组
i := lo
j := mid + 1
for p := lo; p <= hi; p++ {
if i == mid+1 {
nums[p] = temp[j]
j++
} else if j == hi+1 {
nums[p] = temp[i]
i++
} else if temp[i] > temp[j] {
nums[p] = temp[j]
j++
} else {
nums[p] = temp[i]
i++
}
}
}
// 构建前缀和数组,注意 int 可能溢出,用 long 存储
preSum := make([]int64, len(nums)+1)
for i := 0; i < len(nums); i++ {
preSum[i+1] = int64(nums[i]) + preSum[i]
}
// 对前缀和数组进行归并排序
sortHelper(preSum, 0, len(nums))
return count
}
// 注意:javascript 代码由 chatGPT🤖 根据我的 java 代码翻译,旨在帮助不同背景的读者理解算法逻辑。
// 本代码还未经过力扣测试,仅供参考,如有疑惑,可以参照我写的 java 代码对比查看。
var countRangeSum = function(nums, lower, upper) {
this.lower = lower;
this.upper = upper;
// 构建前缀和数组,注意 int 可能溢出,用 long 存储
var preSum = new Array(nums.length + 1);
preSum.fill(0);
for (var i = 0; i < nums.length; i++) {
preSum[i + 1] = nums[i] + preSum[i];
}
// 对前缀和数组进行归并排序
this.sort(preSum);
return this.count;
};
Solution.prototype.sort = function(nums) {
this.temp = new Array(nums.length);
this.sortHelper(nums, 0, nums.length - 1);
};
Solution.prototype.sortHelper = function(nums, lo, hi) {
if (lo === hi) {
return;
}
var mid = lo + Math.floor((hi - lo) / 2);
this.sortHelper(nums, lo, mid);
this.sortHelper(nums, mid + 1, hi);
this.merge(nums, lo, mid, hi);
};
Solution.prototype.merge = function(nums, lo, mid, hi) {
for (var i = lo; i <= hi; i++) {
this.temp[i] = nums[i];
}
// 在合并有序数组之前加点私货(这段代码会超时)
// for (var i = lo; i <= mid; i++) {
// for (var j = mid + 1; j <= hi; k++) {
// // 寻找符合条件的 nums[j]
// var delta = nums[j] - nums[i];
// if (delta <= this.upper && delta >= this.lower) {
// this.count++;
// }
// }
// }
// 进行效率优化
// 维护左闭右开区间 [start, end) 中的元素和 nums[i] 的差在 [lower, upper] 中
var start = mid + 1, end = mid + 1;
for (var i = lo; i <= mid; i++) {
// 如果 nums[i] 对应的区间是 [start, end),
// 那么 nums[i+1] 对应的区间一定会整体右移,类似滑动窗口
while (start <= hi && nums[start] - nums[i] < this.lower) {
start++;
}
while (end <= hi && nums[end] - nums[i] <= this.upper) {
end++;
}
this.count += end - start;
}
// 数组双指针技巧,合并两个有序数组
var i = lo, j = mid + 1;
for (var p = lo; p <= hi; p++) {
if (i === mid + 1) {
nums[p] = this.temp[j++];
} else if (j === hi + 1) {
nums[p] = this.temp[i++];
} else if (this.temp[i] > this.temp[j]) {
nums[p] = this.temp[j++];
} else {
nums[p] = this.temp[i++];
}
}
};
我们依然在 merge
函数合并有序数组之前加了一些逻辑,如果看过前文
滑动窗口核心框架,这个效率优化有点类似维护一个滑动窗口,让窗口中的元素和 nums[i]
的差落在 [lower, upper]
中。
归并排序相关的题目到这里就讲完了,你现在回头体会下我在本文开头说那句话:
所有递归的算法,本质上都是在遍历一棵(递归)树,然后在节点(前中后序位置)上执行代码。你要写递归算法,本质上就是要告诉每个节点需要做什么。
比如本文讲的归并排序算法,递归的 sort
函数就是二叉树的遍历函数,而 merge
函数就是在每个节点上做的事情,有没有品出点味道?
最后总结一下吧,本文从二叉树的角度讲了归并排序的核心思路和代码实现,同时讲了几道归并排序相关的算法题。这些算法题其实就是归并排序算法逻辑中夹杂一点私货,但仍然属于比较难的,你可能需要亲自做一遍才能理解。
那我最后留一个思考题吧,下一篇文章我会讲快速排序,你是否能够尝试着从二叉树的角度去理解快速排序?如果让你用一句话总结快速排序的逻辑,你怎么描述?
好了,答案在下篇文章 快速排序详解及应用 揭晓。
_____________
《labuladong 的算法小抄》已经出版,关注公众号查看详情;后台回复关键词「进群」可加入算法群;回复「全家桶」可下载配套 PDF 和刷题全家桶:
共同维护高质量学习环境,评论礼仪见这里,违者直接拉黑不解释