0115 - Distinct Subsequences
题干
破题
判断一个串的所有(非连续)子序列里,等于目标串的数量。
这里要特别关注解3:子序列的窗口范式
解①正解:
这个是最直接的推导思路,考虑 $s$ 的前缀从 $s[..0]$ 逐个增加字符直到 $s[..n]$ ,观察对样本串 $t$ 它的前缀随之的变化。
假设增加了一个字符 $x$ ,那么对于样本串 $t$ ,所有出现 $x$ 的位置的前缀都会因此受到影响。
假设 $x$ 出现在 $t$ 的所有位置是 P,则有 $\text{dp}[P] := \text{dp}[P] + \text{dp}[P-1]$,也就是 $x$ 前面的前缀会因为 $x$ 的加入,而使得包含 $x$ 的前缀数量增加。特别地,当 $t$ 的前缀是空串的时候,设定值总是为 $1$ ,这样可以方便统一地计算。
那么该如何计算出 $x$ 出现在 $t$ 的所有位置呢,正好字符集很小(大小写字母),直接用列表保存最节省内存,可以在对 $t$ 的一遍扫描过程中把坐标加进对应字符的坐标列表里。
于是得到了下面的实现:
def solve(s: str, t: str) -> int:
if len(s) < len(t):
return 0
if set(s) < set(t):
return 0
pos = [[] for _ in range(26 * 2)]
for i, c in enumerate(t):
key = ord(c) - ord("a") if ord(c) >= ord("a") else ord(c) - ord("A") + 26
pos[key].append(i + 1)
rec = [0] * (len(t) + 1)
for c_i, c in enumerate(s):
pre_v = 1
pre_i = 0
key = ord(c) - ord("a") if ord(c) >= ord("a") else ord(c) - ord("A") + 26
for i in pos[key]:
if i > c_i+1:
break
tmp = rec[i]
if i - 1 == pre_i:
rec[i] += pre_v
else:
rec[i] += rec[i - 1]
pre_v = tmp
pre_i = i
return rec[-1]
有几点需要特别说明的:
1. 快速失败地优化
譬如
if len(s) < len(t):
return 0
if set(s) < set(t):
return 0
这样的代码,几乎总是可以明显提高我们的运行表现,为一些应该很优,但由于 Python 匮乏地运行时优化而表现一般地实现挽尊:—— 。
2. DP 数组的特异处
从代码上乍看起来,这里的 $\text{dp}$ 数组和之前那些一般的 $\text{dp}$ 数组是一样的,我们只是照样把二维压缩到一维。
但这里特别地是,从计算过程上看,每一行并不是所有的数据都会被更新,没有被计算的数据应该继承自上一行,而不是被置为一个固定的初始值。
这样使用二维不仅浪费空间,而且并不方便,使用一维的空间,总是更好的选择。
3. 一维地实现
同样地,在使用压缩地 $\text{dp}$ 数组时,为了避免数据被覆盖,需要提前保存上一行的前一个位置的数据,但是我们预计算的位置并不一定是连续的,因此需要进行判断。对于不连续的位置,就不会存在数据覆盖的问题。
运行时间 94 ms (beats 92.81%),内存占用:16.46 MB (beats 96.82%, ~100%) 。
解②递归+迭代:
实际上如果我们理解了解①,那么可以把它转换成我们更熟悉地 $\text{dp}$ 实现。
只要考虑:
- 要么当 $s$ 的前缀拓展一个字符时,$t$ 也拓展一个字符,这只当 s 和 t 拓展的字符相等时才可以;
- 要么无条件地跳过当前字符,也就是 $s$ 拓展字符,但是 $t$ 不拓展
这样当:
- $s$ 的长度小于 $t$ 时,当然返回 $0$;
- 而当 $s$ 长度等于 $t$ 时,进行比较,返回 $1$ 或者 $0$;
- 当 $m$ 长度为 $0$ 时,也就是递归的一般返回点,或者说初试情况,返回 $1$
递归地版本最简洁,表现也最好:
from functools import cache
@cache
def solve(s: str, t: str) -> int:
n = len(s)
m = len(t)
if n < m:
return 0
elif n == m:
return 1 if s == t else 0
elif m == 0:
return 1
acc = solve(s[1:], t)
if s[0] == t[0]:
acc += solve(s[1:], t[1:])
return acc
运行时间:70 ms (beats 93.39%),内存占用:34.2 MB (beats 72.33%) 。
迭代的版本:
def solve(s: str, t: str) -> int:
n = len(s)
m = len(t)
if n < m:
return 0
rec = [0] * m
for i in range(n):
pre = 1
for j in range(min(i+1, m)):
tmp = rec[j]
if s[i] == t[j]:
rec[j] += pre
pre = tmp
return rec[-1]
运行时间:95 ms (beats 92.78%),内存占用:16.5 MB (beats 96.74%) 。
解③另一个范式:
这里要介绍得是规模递增的另一个范式,很少提到,但是由于有更好地迭代次数,是最快的实现。
对于上面的一般思路来说,是增加 $s$ 的前缀的字符,把问题规模一步步拓展到全局,而对于子序列地匹配,还可以这样考虑:
每一个文本串 $s$ 上面可以匹配目标串 $t$ 的子序列都可以看做,由某些在 $s$ 上的 $t$ 长度模式的对应位置字符构成。
这些构成 $t$ 的子序列的字符可以是连续的,也就是在同一个模式上,也可以不连续,那就是在不同的模式上。
这看起来就像是一个固定窗口在 $s$ 上滑动,探寻窗口对应位置的字符与 $t$ 上对应字符是否相等。
def solve(s: str, t: str) -> int:
n = len(s)
m = len(t)
if n < m:
return 0
rec = [0] * (m+1)
rec[0] = 1
for i in range(n-m+1):
for j in range(m):
if s[i+j] == t[j]:
rec[j+1] += rec[j]
return rec[-1]
Tips:
- 由于在外循环固定减少了 $m$ 长度的轮次,因此是相比上述更快的一个实现;
- 同时快速失败地优化也被省略掉了,因为已经不能提升它的表现了
运行时间:44 ms (beats 99.44%, ~100%),内存占用:16.29 MB (beats 99.85%, ~100%) 。