1340 - Jump Games V
题干
破题
这个题的描述非常别扭,啰嗦还不得要领,配得示例图也抓不到展示的重点,让人很难进入这个题目,单就这个题目文本的编写我就要给出差评!
简而言之,还是跳一个范围 $d$,双向都可以,但是只能从高跳到低,包括中途经过的位置也要比起点低。
数据限制:
1 <= arr.length <= 1000
1 <= arr[i] <= 105
1 <= d <= arr.length
解①原始解:
开始观察数据的特点,发现 $d$ 的大小是整个给定数组 $\text{arr}$ 同级别的,而 $\text{arr}$ 的规模非常小,只有 $10^3$ ,这暗示了 $O(n^2)$ 的解决方法可能不是最优解,但应该足以通过。
在思考这个题的时候,由于没见过这个跳的方法,题目说明得又很蹩脚,头脑混沌不清,充斥着各种可能地思路,又看见了本题有一个排序的标签,觉得排序很可能在最优解里有扮演关键角色,于是构思出了一个相当蹩脚的原始实现,后来证明了大概是能通过的最慢解了。
思路如下:
- 先暴力找到每个位置左右能跳的范围,可以分别用 $\text{forward}$ 和 $\text{backward}$ 两个数组进行保存,花费 $O(d\cdot n)$;
- 对 $\text{arr}$ 进行排序,花费 $O(n\text{log}n)$ ;
- 对排序后的 $\text{arr}$ 进行遍历,因为只能从高到地跳,因此每个位置 $i$ 的下一跳的位置一定在它的前面,于是暴力地检查前面的每个位置,如果它的原始位置是能够一步跳到的,就进行比较,更新 $i$ 处索引最大能跳到的位置,花费 $O(n^2)$
总的时间复杂度 $O(n^2)$
def solve(arr: List[int], d: int) -> int:
n = len(arr)
forward = [0] * n
backward = [0] * n
for i in range(0, n):
r = 0
for j in range(i+1, min(i+d+1, n)):
if arr[i] <= arr[j]:
break
r += 1
forward[i] = r
r = 0
for j in range(i-1, max(0, i-d)-1, -1):
if arr[i] <= arr[j]:
break
r += 1
backward[i] = r
data = sorted(enumerate(arr), key=lambda x: x[1])
cache = [1] * n
for i in range(1, n):
for j in range(i-1, -1, -1):
if data[i][1] == data[j][1]:
continue
if data[j][0] < data[i][0]:
if data[i][0] - data[j][0] <= backward[data[i][0]]:
cache[i] = max(cache[i], cache[j] + 1)
else:
if data[j][0] - data[i][0] <= forward[data[i][0]]:
cache[i] = max(cache[i], cache[j] + 1)
return max(cache)
运行时间:4348 ms (beats 5.24%), 内存占用:16.66 MB (90.32%)
解②记忆化:
这个解是官方给出的,也是最常见的正解:
就是直接在原序列上计算,同样地暴力查找 $d$ 范围内符合条件的下一步位置,保存每个位置的计算结果,像这样时间复杂度也只有 $O(d\cdot n)$ ,比我们之前的方法 $O(n^2)$ 的复杂度要好,即使明明我们的思路看起来要聪明一些。1
这个思路能够成立是因为我们总是从高跳到低,不存在回路。
如果按照 DFS 的顺序计算,非常适合用递归地形式简单解决,配合上 Python 的函数式工具 cache
缓存所有同参数的调用,代码非常简洁:
from typing import List
from itertools import chain, takewhile
from functools import cache
def solve(arr: List[int], d: int) -> int:
n = len(arr)
@cache
def dfs(i: int) -> int:
subs = chain(
takewhile(lambda j: arr[j] < arr[i],
range(i-1, max(0, i-d)-1, -1)),
takewhile(lambda j: arr[j] < arr[i],
range(i+1, min(n-1, i+d)+1, 1)),
)
return max(map(dfs, subs), default=0) + 1
return max(map(dfs, range(n))) # n >= 1
运行时间:351 ms (beats 88.49%) ,内存占用:20.99 MB (beats 17.46%)
概念解剖
这里提到了“记忆化搜索”的概念,它是一个单独的分类,但从实质上讲也可以认为是 DP 的一类特定情况,这里就不特别区分二者了。
语言责难
在我们当前的环境,Python 3.10 ,如果不使用 cache
而是手动保存结果的话,如下所示:
# ...
def solve(arr: List[int], d: int) -> int:
n = len(arr)
cache = [0] * n
def dfs(i: int) -> int:
if cache[i]:
return cache[i]
subs = chain(
takewhile(lambda j: arr[j] < arr[i], range(i-1, max(0, i-d)-1, -1)),
takewhile(lambda j: arr[j] < arr[i], range(i+1, min(n-1, i+d)+1, 1)),
)
cache[i] = max(map(dfs, subs), default=0) + 1
return cache[i]
return max(map(dfs, range(n)))
内存占用不出所料地降低到了:19.17 MB (beats 54.37%) ,但是运行时间也提高到了: 421 ms (beats 80.95%) ,性能损失达到了 20% ,换言之,使用 cache
虽然增加了内存占用,但相比手动缓存,性能有明显提升。
如果说 cache
作为 functools
这个函数式编程辅助模块的一个代表2,表现还是不错的,那另一个函数式编程的辅助模块 itertools
的表现就有些拉胯了:这些迭代器上的函数的性能可观察地不如普通地 for
循环。
比如把上面的
subs = chain(
takewhile(lambda j: arr[j] < arr[i], range(i-1, max(0, i-d)-1, -1)),
takewhile(lambda j: arr[j] < arr[i], range(i+1, min(n-1, i+d)+1, 1)),
)
替换成
subs = []
for j in range(i-1, max(0, i-d)-1, -1):
if arr[j] < arr[i]:
subs.append(j)
else:
break
for j in range(i+1, min(n-1, i+d)+1, 1):
if arr[j] < arr[i]:
subs.append(j)
else:
break
就有有 10% 级别的性能提升,这在像 Rust 这样的语言里是难以想象的!
解③排序+反向构建:
继续考虑上面的记忆化方法,DFS 的搜索顺序配合递归地实现只是说是一种实现方案,它很简洁很容易写,因此通常我们默认采用这种方法,但实际上当然也可以采用 BFS 的顺序以及迭代地实现,区别只是:
DFS 的迭代版本需要手动保存每一级的迭代器状态,而 BFS 需要保存某一级的所有孩子(当然还有该级的级数,也就是之前已经跳了多少步),不管怎样,实际问题地解决总是要到没有后代的节点,才能解决。
如果一个节点本身就是最小的高度,那么它一定没有子节点,或者周围 $d$ 范围内都是不低于它的节点,或者倒过来想,如果知道了所有高度严格小于某个节点的的节点信息(最大访问数、前向范围和后向范围),那么该节点的信息也可以利用这些信息3来直接得到:
-
最大访问数是前一级节点的最大访问数加一,然后取它们的最大值;
-
前一级节点的前向(顺扫描的方向)或后向(逆扫描的方向)范围,以后向为例,先选取后向的第一个元素,只要它小于当前元素,那么在不超过 $d$ 的范围,就可以直接把它的后向距离加入到当前位置的后向距离,然后从那个距离处继续向后扫描,直到超出 $d$ 的范围或者发现不小于当前元素的元素
有必要特别说明下这个前向或后向的范围,缓存这个值,可以减少位置值的比较次数3,虽然在求解最大访问数的时候复杂度仍然是 $O(d\cdot n)$ ,但实际证明这个做法让性能提高了 40% 以上。
def solve(arr: List[int], d: int) -> int:
n = len(arr)
ans_cache = [1] * n
forward_cache = [0] * n
backward_cache = [0] * n
data = sorted(enumerate(arr), key=lambda x: x[1])
for i, v in data:
# d >= 1
r = 0
while i+r+1 < n and r < d and arr[i+r+1] < v:
r += forward_cache[i+r+1] + 1
forward_cache[i] = min(r, d)
l = 0
while i-l > 0 and l < d and arr[i-l-1] < v:
l += backward_cache[i-l-1] + 1
backward_cache[i] = min(l, d)
ans_cache[i] += max(chain(ans_cache[i-l:i], ans_cache[i+1:i+r+1]), default=0)
return max(ans_cache)
这个方法本质上是记忆化方法的改进,省去了堆栈的开销,时间复杂度是 Better $O(d\cdot n)$ 。
运行时间:199 ms (beats 88.3%),内存占用:16.9 MB (beats 73.24%)
实际结果看相当好了,与最优解相差不大了。
解➃贪心+记忆化:
贪心思想地应用是时间复杂度最终能降低到 $O(n)$ 的关键4。
如之前 Jump Games 所示,$d$ 既然是连续的,就有贪心算法应用的空间,只不过还有一个限制条件,高度也要更低。
从反方向来考虑那些最长访问次数的位置序列,每个位置 $i$ 的前一个位置一定是距离它最近的更高的位置(当然左右方向分别考虑),否则如果不是最近的高度,要么从这个高度要么跳不到 $i$ ,因为中间还有不低于这个高度的位置,要么可以先跳到中间的一个高度的位置,然后跳到 $i$ ,总之都不是 $i$ 的直接前驱。
于是可以在 $O(n)$ 的时间复杂内找到每个位置的可能的直接前驱(分前向后向两个方向),然后可以在另一个 $O(n)$ 的复杂度内带记忆化地遍历每个位置,通过它的最多两个前驱,计算该位置的最大访问数,最后取最大值。
from typing import List, Iterator
from functools import cache
def solve(arr: List[int], d: int) -> int:
n = len(arr)
forward = [-1] * n
backward = [-1] * n
def build(range: Iterator[int], storage: List[int]):
stack = []
for i in range:
while stack and arr[stack[-1]] < arr[i] and abs(stack[-1] - i) <= d:
storage[stack.pop()] = i
stack.append(i)
build(range(n), forward)
build(range(n-1, -1, -1), backward)
@cache
def backtrace(i: int) -> int:
if i == -1:
return 0
return max(backtrace(forward[i]), backtrace(backward[i])) + 1
return max(map(backtrace, range(n)))
时间复杂度 $O(n)$ 。
运行时间:152 ms (beats 98.41%),内存占用:19.52 MB (beats 42.46%)
解➄贪心+单调栈+双指针+DP:
前面的 $O(n)$ 复杂度的实现已经够好了,但是还有同复杂度但是更快的方法。
这个方法是从最快的样本代码里扒出来的,这个实现不仅融汇贯通了这些算法思想,而且代码细节上也是 Python 最优,几乎让人无法增删一笔,实在是优秀。
它区别于前一个方法的地方在于它把计算两个方向上的最近更高位置和求解最大访问次数能够放在一趟遍历过程中同时解决!
它的方法是这样:
- 使用单调栈至少可以计算后向的最近更高位置,而利用出栈的元素,又可以更新栈里前向的最近更高位置;
- 实际上不需要维护最近更高位置,可以直接更新某个位置的最大访问次数
具体过程:
扫描输入数组 $\text{arr}$ ,只要栈顶元素小于当前元素,就出栈;同时每次出栈时都把连带着所有与栈顶相等元素一同出栈,除了用栈顶元素从后向更新当前位置的最大访问数,还要利用所有连同出栈的相等的元素来前向更新新的栈顶元素的最大访问数(由于是单调不减栈,而且与旧栈顶相等的元素都出栈了,所以新栈顶一定严格大于出栈的所有元素)。
这样有一个问题,当扫描完成时,栈里的元素不一定都全部出栈了,因为最后一个元素不一定是最大的元素,这时要像前面一样,继续弹出栈顶元素,来前向地更新栈里后面的元素。
单独再加一段儿处理代码会很啰嗦,更合适地方法是在原数组尾部插入一个确定一定比所有元素都大的值,只是在求最大访问次数时排除尾部就行了。
from typing import List
max_int = 100_000 + 1
def solve(arr: List[int], d: int) -> int:
arr.append(max_int)
n = len(arr)
stack = []
dp = [1] * n
for i in range(n):
while stack and arr[stack[-1]] < arr[i]:
isoheight = [stack.pop()]
while stack and arr[stack[-1]] == arr[isoheight[0]]:
isoheight.append(stack.pop())
for j in isoheight:
if i-j <= d:
dp[i] = max(dp[i], dp[j]+1)
if stack and j-stack[-1] <= d:
dp[stack[-1]] = max(dp[stack[-1]], dp[j]+1)
stack.append(i)
return max(dp[:-1])
时间复杂度 Better $O(n)$
运行时间:120 ms (beats 100%),内存占用:16.6 MB (beats 96.6%)
这段代码不仅从算法方法上,而且从代码编写上,实在已经是 Python 下面最快地写法了,任何比如拆分两个更新,使用迭代器都比这个实现更慢。