首页 > 代码库 > hdu 3068 最长回文(manacher算法)

hdu 3068 最长回文(manacher算法)

最长回文

                                                                        Time Limit: 4000/2000 MS (Java/Others)    Memory Limit: 32768/32768 K (Java/Others)

Problem Description
给出一个只由小写英文字符a,b,c...y,z组成的字符串S,求S中最长回文串的长度.
回文就是正反读都是一样的字符串,如aba, abba等
 

Input
输入有多组case,不超过120组,每组输入为一行小写英文字符a,b,c...y,z组成的字符串S
两组case之间由空行隔开(该空行不用处理)
字符串长度len <= 110000
 

Output
每一行一个整数x,对应一组case,表示该组case的字符串中所包含的最长回文长度.
 

Sample Input
aaaa abab
 

Sample Output
4 3
 
这个题我先用后缀数组+最长公共前缀做的,但是超时了(可能是我的代码写的太烂了)。后来搜题解,才发现大家都是用manacher算法做的,时间复杂度为O(n)。第一次听说这个算法,于是就学了一下。

定义数组p[i]表示以i为中心的(包含i这个字符)回文串半径长

将字符串s从前扫到后for(int i=0;i<strlen(s);++i)来计算p[i],则最大的p[i]就是最长回文串长度,则问题是如何去求p[i]?

由于s是从前扫到后的,所以需要计算p[i]时一定已经计算好了p[1]....p[i-1]

假设现在扫描到了i+k这个位置,现在需要计算p[i+k]

定义maxlen是i+k位置前所有回文串中能延伸到的最右端的位置,即maxlen=p[i]+i;//p[i]+i表示最大的

分两种情况:

1.i+k这个位置不在前面的任何回文串中,即i+k>maxlen,则初始化p[i+k]=1;//本身是回文串

然后p[i+k]左右延伸,即while(s[i+k+p[i+k]] == s[i+k-p[i+k]])++p[i+k]

2.i+k这个位置被前面以位置i为中心的回文串包含,即maxlen>i+k

这样的话p[i+k]就不是从1开始


由于回文串的性质,可知i+k这个位置关于i与i-k对称,

所以p[i+k]分为以下3种情况得出

//黑色是i的回文串范围,蓝色是i-k的回文串范围,





根据上面的算法可以得出:p[i]是以i为中心的回文串长度,那么对于aaaa这样的字符串求回文字符串时发现对称中心不是一个字符,而是空的,所以要把偶数字符串变成奇数字符串,方法就是在字符串中插入字符串中没有出现过的字符,例如‘#‘。

核心代码:

for(int i = 1; i < len; i++) {
    p[i] = mmax > i ? min(p[id*2-i], mmax - i) : 1;
    while(s[i+p[i]] == s[i-p[i]]) p[i]++;
    if(i + p[i] > id + p[id]) {
        id = i;
        mmax = i + p[i];
    }
}
最长回文长度就是mmax-1。

p[i]为回文半径,如果该半径以’#‘开始,即回文串为‘#‘‘s[i]‘#‘……‘#‘,则一定以‘#‘结束,所以mmax-1以后‘#‘和‘s[]‘一样多,即mmax-1是原串以i为中心的回文字符串长度。如果该半径是以s[]开始的,即‘s[]‘‘#‘……‘#‘‘s[]‘,则回文串长度是p[i]-1。


#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N = 220005;
char str[N];
int p[N];

void manacher(char *s, int len)
{
    p[0] = 1;
    int mmax = 0, id = 0;
    for(int i = 1; i < len; i++) {
        p[i] = mmax > i ? min(p[id*2-i], mmax - i) : 1;
        while(s[i+p[i]] == s[i-p[i]]) p[i]++;
        if(i + p[i] > id + p[id]) {
            id = i;
            mmax = i + p[i];
        }
    }
}

int main()
{
    while(~scanf("%s",str)) {
        int len = strlen(str);
        for(int i = len; i >= 0; i--) {
            str[(i<<1) + 1] = '#';
            str[(i<<1) + 2] = str[i];
        }
        str[0] = '*'; //防止数组越界
        len = len * 2 + 2;
        manacher(str, len);
        int ans = 0;
        for(int i = 0; i < len; i++)
            ans = max(ans, p[i]-1);
        printf("%d\n", ans);
    }
    return 0;
}