Cancel

0005 - Longest Palindromic Substring

Other

·

July 05, 2023

题干

问题描述

破题

选取最长的回文子串,如果有多个,任选一个即可。参考专门的回文章节,这里只列出 Python 实现。

源代码

解①朴素实现:

"""O(n^2) 448ms, beats 92.33%"""

def solve(s: str) -> str:
    """Return the longest palindromic substring of ``s`` by expanding around
    every possible centre.

    Odd-length candidates are centred on a character, even-length ones on
    the gap between two neighbouring characters.  Among equally long
    palindromes the leftmost one is returned; an empty input yields ``""``.
    O(n^2) time, O(1) extra space.
    """
    size = len(s)

    def arm(left: int, right: int) -> int:
        # Number of consecutive matching pairs s[left-k] == s[right+k]
        # (k = 0, 1, ...) before a mismatch or a string edge is reached.
        steps = 0
        while (left - steps >= 0 and right + steps < size
               and s[left - steps] == s[right + steps]):
            steps += 1
        return steps

    # (radius, centre) of the best odd palindrome s[c-r .. c+r].
    best_odd = (0, 0)
    for centre in range(1, size - 1):
        radius = arm(centre - 1, centre + 1)
        if radius > best_odd[0]:
            best_odd = (radius, centre)

    # (radius, centre) of the best even palindrome s[c-r+1 .. c+r].
    best_even = (0, 0)
    for centre in range(size - 1):
        radius = arm(centre, centre + 1)
        if radius > best_even[0]:
            best_even = (radius, centre)

    odd_r, odd_c = best_odd
    even_r, even_c = best_even

    # Odd length is 2*odd_r + 1 and even length is 2*even_r, so the even
    # palindrome is longer exactly when even_r > odd_r.
    if odd_r < even_r:
        return s[even_c - even_r + 1 : even_c + even_r + 1]
    return s[odd_c - odd_r : odd_c + odd_r + 1]

解②串哈希:

""" O(nlogn) 1072ms beats 42.67% """

from typing import List
from math import log2


def solve(s: str) -> str:
    """Return the longest palindromic substring of ``s`` via prefix hashing.

    For every centre, the largest palindrome radius is found by binary
    lifting (trying ``reach + 2**k`` for decreasing ``k``).  This is valid
    because "radius r is palindromic" is monotone in r.  Each check compares
    the hash of the right half read forwards with the hash of the left half
    read in the reversed string, so it costs O(1).  Ties keep the leftmost
    centre.
    """
    n = len(s)

    fwd = PrefixHash(s)
    rev = PrefixHash(''.join(reversed(s)))

    def odd_ok(c: int, r: int) -> bool:
        # s[c .. c+r-1] equals s[c-r+1 .. c] read backwards.
        return fwd.query(c, c + r - 1) == rev.query(n - 1 - c, n - 2 - c + r)

    def even_ok(c: int, r: int) -> bool:
        # s[c .. c+r-1] equals s[c-r .. c-1] read backwards.
        return fwd.query(c, c + r - 1) == rev.query(n - c, n - c + r - 1)

    def lift(ok, c: int, limit: int) -> int:
        # Largest r in [0, limit] with ok(c, r), probing from the top bit down.
        reach = 0
        for k in range(int(log2(limit)), -1, -1):
            step = reach + (1 << k)
            if step <= limit and ok(c, step):
                reach = step
        return reach

    # Odd palindromes: radius r spans 2*r - 1 characters centred on index c.
    odd_r, odd_i = 1, 0
    for c in range(n):
        limit = min(c + 1, n - c)
        if limit > odd_r:
            r = lift(odd_ok, c, limit)
            if r > odd_r:
                odd_r, odd_i = r, c

    # Even palindromes: radius r spans 2*r characters centred on gap (c-1, c).
    even_r, even_i = 0, 1
    for c in range(1, n):
        limit = min(c, n - c)
        if limit > even_r:
            r = lift(even_ok, c, limit)
            if r > even_r:
                even_r, even_i = r, c

    # 2*even_r >= 2*odd_r - 1 exactly when even_r >= odd_r.
    if even_r >= odd_r:
        return s[even_i - even_r : even_i + even_r]
    return s[odd_i - odd_r + 1 : odd_i + odd_r]


# Hash parameters for inputs made only of digits and English letters.
# Their ord() codes lie in 48..122, so two characters differ by at most
# 74 < 79.  With base P = 79, distinct strings of equal length therefore have
# distinct un-reduced polynomial values; only the modulus can cause collisions.
P = 79
# Prime modulus: the usual collision-probability bounds for polynomial
# hashing assume a prime modulus (the former 10**9 = 2**9 * 5**9 is not).
M = 10 ** 9 + 7
N = 1000  # Max string length covered by the shared NPOWS table

# NPOWS[k] == P**k % M for k < N.  Kept for compatibility; PrefixHash builds
# its own power table, so it also works for strings longer than N.
NPOWS = [1] * N

for i in range(1, N):
    NPOWS[i] = NPOWS[i-1] * P % M

# A custom character ranking (digits / upper / lower case mapped to 1..62)
# was considered and dropped: ord() already yields distinct non-zero codes
# that base 79 handles, so the extra mapping only costs time.


class PrefixHash:
    """Polynomial prefix hashes of a string for O(1) substring-hash queries.

    ``prefix[i]`` is the hash of ``s[0..i]`` (inclusive).  ``query(l, r)``
    returns the hash of ``s[l..r]`` (inclusive).  Equal substrings always
    hash equal; unequal ones may collide with low probability.
    """

    def __init__(self, s: str) -> None:
        self.prefix = [0] * len(s)

        # pows[k] == P**k % M for 0 <= k <= len(s).  The table is per instance
        # (not the global NPOWS) so strings longer than N do not overflow it.
        self.pows = [1] * (len(s) + 1)
        for k in range(1, len(s) + 1):
            self.pows[k] = self.pows[k - 1] * P % M

        self.build_prefix(self.prefix, s)

    @staticmethod
    def build_prefix(l: List[int], s: str):
        """Fill ``l`` in place with the prefix hashes of ``s``.

        ``l`` must already have length ``len(s)``.
        """
        if not s:
            return

        l[0] = ord(s[0])

        for i in range(1, len(s)):
            l[i] = (l[i-1] * P % M + ord(s[i])) % M

    def query(self, l: int, r: int) -> int:
        """Return the hash of ``s[l..r]`` (both ends inclusive)."""

        a = self.prefix[r]
        # Shift the hash of s[0..l-1] up by (r-l+1) positions before removing it.
        b = self.prefix[l-1] * self.pows[r - l + 1] % M if l > 0 else 0

        return (a - b) % M

解③哈希DP:

""" O(n) 251ms beats 95.4% """

from typing import List

def solve(s: str) -> str:
    """Return the longest palindromic substring of ``s`` (hash-based DP).

    ``run`` is the length of the longest palindrome ending at the current
    index.  Stripping both ends of a palindrome ending at ``end`` leaves one
    ending at ``end - 1``, so the new run is at most ``run + 2``.  Lengths
    are therefore probed downward from that bound, and the first palindrome
    found is kept.  Total probing is amortised O(n), and each probe is an
    O(1) hash comparison.  Ties keep the leftmost palindrome.
    """
    n = len(s)

    fwd = PrefixHash(s)
    rev = PrefixHash(''.join(reversed(s)))

    def is_pal(end: int, length: int) -> bool:
        # s[end-length+1 .. end] read forwards equals it read backwards.
        return (fwd.query(end - length + 1, end)
                == rev.query(n - 1 - end, n - 2 - end + length))

    best_len, best_end = 1, 0
    run = 1

    for end in range(1, n):
        # Longest palindromic length ending at `end`; keep `run` if none found.
        run = next(
            (length for length in range(min(end + 1, run + 2), 0, -1)
             if is_pal(end, length)),
            run,
        )

        if run > best_len:
            best_len, best_end = run, end

    return s[best_end - best_len + 1 : best_end + 1]

解④Manacher:

""" O(n) 99ms, beats 98.74% """

from typing import List


def solve(s: str) -> str:
    """Return the longest palindromic substring of ``s`` using Manacher's
    algorithm, run separately for odd- and even-length palindromes.

    Returns ``""`` for an empty string.  Without the guard, ``max`` over the
    empty radius lists would raise ``ValueError``.
    """
    if not s:
        return s

    d1 = compute_d1(s)
    d2 = compute_d2(s)

    (odd_i, odd_r) = map_d(d1)
    (even_i, even_r) = map_d(d2)

    # Odd length is 2*odd_r - 1 and even length is 2*even_r, so the even
    # palindrome is at least as long exactly when even_r >= odd_r.
    if even_r >= odd_r:
        return s[even_i-even_r : even_i+even_r]
    else:
        return s[odd_i-odd_r+1 : odd_i+odd_r]


def map_d(l: List[int]):
    """Return ``(index, value)`` of the first maximum of the radius list ``l``.

    Raises ``ValueError`` if ``l`` is empty.
    """
    return max(enumerate(l), key=lambda x: x[1])


def compute_d1(s: str) -> List[int]:
    """Return the odd-palindrome radius list of ``s`` (Manacher).

    ``d1[i] = r`` means ``s[i-r+1 .. i+r-1]`` is the longest palindrome
    centred on index ``i``; its length is ``2*r - 1``.
    """

    n = len(s)

    d1 = [1] * n

    # [rl, rr]: the palindrome found so far whose right end reaches furthest.
    rl = 0  # rightmost left
    rr = 0  # rightmost right

    # The first and last characters can only have radius 1.
    for i in range(1, n-1):
        if i < rr:
            # j mirrors i around the centre of [rl, rr].
            j = rr - i + rl

            if d1[j] < j - rl + 1:
                # The mirrored palindrome lies strictly inside [rl, rr],
                # so the palindrome at i has exactly the same radius.
                d1[i] = d1[j]
                continue
            else:
                # Radius is at least rr - i + 1 (== j - rl + 1); try to extend.
                r = j - rl + 1
        else:
            r = 1

        # Extend while s[i+r] and s[i-r] are both in range and match.
        while i+r-1 < n-1 and i-(r-1) > 0 and s[i+r] == s[i-r]:
            r += 1

        if i+r-1 > rr:
            rr = i+r-1
            rl = i-(r-1)

        d1[i] = r

    return d1


def compute_d2(s: str) -> List[int]:
    """Return the even-palindrome radius list of ``s`` (Manacher).

    ``d2[i] = r`` means ``s[i-r .. i+r-1]`` is the longest palindrome centred
    on the gap between indices ``i-1`` and ``i``; its length is ``2*r``.
    ``d2[0]`` is always 0.
    """

    n = len(s)

    d2 = [0] * n

    # [rl, rr]: the palindrome found so far whose right end reaches furthest.
    rl = 0  # rightmost left
    rr = 0  # rightmost right

    for i in range(1, n):
        if i < rr:
            # Gap (j-1, j) mirrors gap (i-1, i) around the centre of [rl, rr].
            j = rr - i + rl + 1

            if d2[j] < j - rl:
                # The mirrored palindrome lies strictly inside [rl, rr].
                d2[i] = d2[j]
                continue
            else:
                # Radius is at least rr - i + 1 (== j - rl); try to extend.
                r = j - rl
        else:
            r = 0

        # Extend while s[i-r-1] and s[i+r] are both in range and match.
        while i+r-1 < n-1 and i-r > 0 and s[i-r-1] == s[i+r]:
            r += 1

        if i+r-1 > rr:
            rr = i+r-1
            rl = i-r

        d2[i] = r

    return d2

解⑤Manacher 奇偶合一(插入分隔符):

""" O(n) 138ms, beats 96.32%"""

from typing import List

def solve(s: str) -> str:
    """Return the longest palindromic substring of ``s`` with a single
    Manacher pass over ``'#'.join(s)``.

    Interleaving a separator turns every palindrome of ``s`` into an
    odd-length palindrome of the joined string.  Centres at even joined
    indices are characters of ``s`` and give odd-length palindromes.  Centres
    at odd joined indices are separators and give even-length palindromes.
    Mirrored positions always share parity, so separators are only compared
    with separators.

    Relies on ``compute_d1`` (odd-palindrome radius list) from the Manacher
    solution.
    """
    # Length 0 or 1 is its own answer.  Handling "" here also avoids max()
    # over an empty sequence below.
    if len(s) <= 1:
        return s

    joined = '#'.join(s)
    d = compute_d1(joined)

    # First maximal radius among character centres and among separator centres.
    odd_i1 = max(range(0, len(d), 2), key=d.__getitem__)
    even_i1 = max(range(1, len(d), 2), key=d.__getitem__)
    odd_r1, even_r1 = d[odd_i1], d[even_i1]

    # Map back to s: a character centre at joined index 2m is s[m], and a
    # separator at joined index 2m+1 is the gap before s[m+1].
    odd_i = odd_i1 // 2
    odd_r = (odd_r1 + 1) // 2

    even_i = (even_i1 + 1) // 2
    even_r = even_r1 // 2

    if even_r >= odd_r:
        return s[even_i-even_r: even_i+even_r]
    else:
        return s[odd_i-odd_r+1: odd_i+odd_r]

注解

© minghu6

·

theme

Simplex theme logo

by golas