1458 - Max Dot Product of Two Subsequences
题干
破题
一般地两串上的DP是从两个前缀子串开始的更小规模的问题扩展到全串,从而直接解决题目要求的问题,而本题特殊的地方是子序列并不是完整的前缀子串。
不妨把前缀子串视为 border,而譬如子序列则是 content ,它和 border之间有 padding ,其他普通的问题可以通过 border 就可以直接计算出来,而对于子序列这样的问题,需要额外考虑是否需要“填充” padding。
“padding” 的引入让问题稍微复杂了一点,但主要还是干扰了我们破题的思路,如果我们能把保持一个清晰的思路,解题还是很容易的。
第一版:
因为乘积存在负数,并且不允许空的点积序列,因此先排除乘积全是负数的情况。
之后在 DP 过程中只需要考虑是否把非负的乘积加进来即可,这也暗示了“padding”的判断方法:只有正的乘积会被加进来,负的会被视为“padding”跳过。
def solve(nums1: List[int], nums2: List[int]) -> int:
n = len(nums1)
m = len(nums2)
if nums1[0] * nums2[0] < 0\
and all(map(lambda x: x * nums1[0] > 0, nums1[1:]))\
and all(map(lambda x: x * nums2[0] > 0, nums2[1:])):
return -min(map(lambda x:abs(x), nums1)) * min(map(lambda x:abs(x), nums2))
cache = [[0] * m for _ in range(n)]
cache[0][0] = max(0, nums1[0]*nums2[0])
for j in range(1, m):
cache[0][j] = max(nums1[0]*nums2[j], cache[0][j-1])
for i in range(1, n):
cache[i][0] = max(nums1[i]*nums2[0], cache[i-1][0])
for j in range(1, m):
cache[i][j] = max(
cache[i-1][j],
cache[i][j-1]
)
if nums1[i]*nums2[j]+cache[i-1][j-1] > cache[i][j] :
cache[i][j] = nums1[i]*nums2[j]+cache[i-1][j-1]
return cache[n-1][m-1]
282 ms (beats 82.09%), 19.02 MB (beats 66.24%)
改进:
- 采用了一个稍微更好一点的对于乘积全是负数的情况的判断方法,取两个极值进行比较,但是因此要分两种情况进行判断;
- 按照前面滚动数组的思路讲空间占用压缩到一维;
- 合并了 DP 核心算法里的分支判断
def solve(nums1: List[int], nums2: List[int]) -> int:
if len(nums1) < len(nums2):
nums1, nums2 = nums2, nums1
n = len(nums1)
m = len(nums2)
nums1_max = max(nums1)
nums2_min = min(nums2)
if nums1_max < 0 and nums2_min > 0:
return nums1_max * nums2_min
nums1_min = min(nums1)
nums2_max = max(nums2)
if nums1_min > 0 and nums2_max < 0:
return nums1_min * nums2_max
cache = [0] * m
for i in range(n):
pre1 = cache[0]
cache[0] = max(nums1[i]*nums2[0], cache[0])
for j in range(1, m):
pre0 = cache[j]
cache[j] = max(
pre0,
cache[j-1],
nums1[i]*nums2[j]+pre1,
nums1[i]*nums2[j]
)
pre1 = pre0
return cache[m-1]
183 ms (beats 95.32%), 16.34 MB (beats 92.92%, ~100%)
另版:
还有一点稍稍不同思路的是把全负乘积考虑进一般情况,但是对于压缩的 DP 数组来说,需要额外处理初始值。
def solve(nums1: List[int], nums2: List[int]) -> int:
if len(nums1) < len(nums2):
nums1, nums2 = nums2, nums1
n = len(nums1)
m = len(nums2)
cache = [0] * m
cache[0] = nums1[0]*nums2[0]
for j in range(1, m):
cache[j] = max(nums1[0]*nums2[j], cache[j-1])
for i in range(1, n):
pre1 = cache[0]
cache[0] = max(nums1[i]*nums2[0], cache[0])
for j in range(1, m):
pre0 = cache[j]
cache[j] = max(
pre0,
cache[j-1],
nums1[i]*nums2[j],
nums1[i]*nums2[j]+pre1,
)
pre1 = pre0
return cache[m-1]
189 ms (beats 95.32%), 16.29 MB (beats 97.06%, ~100%)