首页 > 代码库 > POJ 1743 后缀数组:求最长不重叠子串

POJ 1743 后缀数组:求最长不重叠子串

数据:这题弄了好久,WA了数十发,现在还有个例子没过,可却A了,POJ 的数组也太弱了。

10
1 1 1 1 1 1 1 1 1 1

这组数据如果没有那个n-1<10判断的话,输入的竟然是5,我靠……

思路:这个题目关键的地方有两个:第一,重复的子串一定可以看作是某两个后缀的公共前缀,第二,把题目转化成去判定对于任意的一个长度k,是否存在长度至少为k的不重叠的重复的子串。

    转化成判定问题之后,就可以二分去解答了。在验证判定是否正确时,我们可以把相邻的所有不小于k的height[]看成一组,然后去看每个组内sa[]值的最大值与最小值之差是否满足大于等于k,如果有某一组满足这个条件,那么在这个组内就一定可以找到长度不小于k且不重叠的重复的子串。

    因为排名第i的字符串和排名第j的字符串的最长公共前缀等于height[i],height[i+1],...,height[j]中的最小值,所以把所有不小于k的height[]看成一组就保证了组内任意两个字符串的最长公共前缀都至少为k。其中最有可能形成不重叠的重复的子串就是组内sa[]值最大的字符串与sa[]值最小的字符串。

解法一:

按照分组的想法:AC代码如下,但是自己的样例没过就A了,原因是分组的时候我下标取的是sa[i-1]才过的但是11个1的样例就不行:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<map>
#include<queue>
#include<set>
#include<cmath>
#include<bitset>
#define mem(a,b) memset(a,b,sizeof(a))
#define lson i<<1,l,mid
#define rson i<<1|1,mid+1,r
#define llson j<<1,l,mid
#define rrson j<<1|1,mid+1,r
#define INF 0x7fffffff
#define maxn 20010
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
void radix(int *str,int *a,int *b,int n,int m)
{
    static int count[maxn*2];
    mem(count,0);
    for(int i=0;i<n;i++) ++count[str[a[i]]];
    for(int i=1;i<=m;i++) count[i]+=count[i-1];
    for(int i=n-1;i>=0;i--) b[--count[str[a[i]]]]=a[i];
}
void suffix(int *str,int *sa,int n,int m) //倍增算法计算出后缀数组sa
{
    static int rank[maxn*2],a[maxn*2],b[maxn*2];
    for(int i=0;i<n;i++) rank[i]=i;
    radix(str,rank,sa,n,m);
    rank[sa[0]]=0;
    for(int i=1;i<n;i++)
        rank[sa[i]]=rank[sa[i-1]]+(str[sa[i]]!=str[sa[i-1]]);
    for(int i=0;1<<i<n;i++)
    {
        for(int j=0;j<n;j++)
        {
            a[j]=rank[j]+1;
            b[j]=j+(1<<i)>=n?0:rank[j+(1<<i)]+1;
            sa[j]=j;
        }
        radix(b,sa,rank,n,n);
        radix(a,rank,sa,n,n);
        rank[sa[0]]=0;
        for(int j=1;j<n;j++)
            rank[sa[j]]=rank[sa[j-1]]+(a[sa[j-1]]!=a[sa[j]]||b[sa[j-1]]!=b[sa[j]]);
    }
}
void calcHeight(int *str,int *sa,int *h,int n) //求出最长公共前缀数组h
{
    static int rank[maxn*2];
    int k=0;
    h[0]=0;
    for(int i=0;i<n;i++) rank[sa[i]]=i;
    for(int i=0;i<n;i++)
    {
        k=k==0?0:k-1;
        if(rank[i])
            while(str[i+k]==str[sa[rank[i]-1]+k]) k++;
        else k=0;
        h[rank[i]]=k;
    }
}
int a[maxn],sa[maxn],height[maxn];
bool binary(int mid,int n)
{
    int i=1;
    while(1)
    {
        while(i<n&&height[i]<mid) i++;
        if(i==n) break;
        int Max=-INF,Min=INF;
        while(i<n&&height[i]>=mid) //按照分组的思想,比目标值大的在同一组
        {
            Max=max(sa[i],Max);
            Min=min(sa[i],Min);
            Max=max(sa[i-1],Max); //因为height[i]是rank[sa[i-1]]和rank[sa[i]]的最长公共前缀,所以这样取下标
            Min=min(sa[i-1],Min); //但是在11个1的样例中输出6是错的,因为取了sa[i-1]的下标错的,但是少了题目样例又不得过,实在不知道怎么避免了……^-^
            if(Max-Min>=mid) return true;
            i++;
        }
    }
    return false;
}
int main()
{
    //freopen("test.txt","r",stdin);
    int n,i;
    while(~scanf("%d",&n)&&n)
    {
        for(i=0;i<n;i++)
            scanf("%d",a+i);
        if(n-1<10)
        {
            puts("0");
            continue;
        }
        for(i=0;i<n-1;i++)
            a[i]=a[i+1]-a[i]+90;
        a[--n]=0;
        mem(sa,0),mem(height,0);
        suffix(a,sa,n,200);
        calcHeight(a,sa,height,n);
        //for(i=1;i<n;i++)
        //    cout<<height[i]<<' '<<sa[i]<<endl;
        int l=1,r=n,mid,ans=0;
        while(l<=r)
        {
            mid=(l+r)>>1;
            if(binary(mid,n)) ans=mid,l=mid+1;
            else r=mid-1;
        }
        if(ans<4) puts("0");
        else printf("%d\n",ans+1);
    }
    return 0;
}
/*
8
1 1 2 1 1 1 1 2
10
1 1 1 1 1 1 1 1 1 1
11
1 1 1 1 1 1 1 1 1 1 1
*/


解法二:

这是网上的解决方法,和前i个下标的最大最小比较,这个我有点不太理解,原理好像也是按照分组的思想好像吧,但是>=mid的和<mid的下标比较好像又不是分组的思想,实在让我有点不理解了……请高人指点一二啊!!!

还有个疑问:

问题是suffix(a,sa,n+1,200)这个函数传进去的数组长度为n+1,但是我改成n的时候也是对的,虽然说POJ 的数组很弱嘛,但是也没有弱到这样去吧,这个数组的长度刚开始为什么要+1呢,本来所求的后缀就是两个相差的值呀,我写n也对啊,网上的代码都说要+1才行,但是我测试没+1也过了,奇怪了?

先做后缀数组的其他题,等熟练运用后缀数组了再回来看看这题的代码改一下还有解决上面这个问题吧。我还是用自己分组的思想像上面那个代码解决问题比较好理解。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<map>
#include<queue>
#include<set>
#include<cmath>
#include<bitset>
#define mem(a,b) memset(a,b,sizeof(a))
#define lson i<<1,l,mid
#define rson i<<1|1,mid+1,r
#define llson j<<1,l,mid
#define rrson j<<1|1,mid+1,r
#define INF 0x7fffffff
#define maxn 20010
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
void radix(int *str,int *a,int *b,int n,int m)
{
    static int count[maxn*2];
    mem(count,0);
    for(int i=0;i<n;i++) ++count[str[a[i]]];
    for(int i=1;i<=m;i++) count[i]+=count[i-1];
    for(int i=n-1;i>=0;i--) b[--count[str[a[i]]]]=a[i];
}
void suffix(int *str,int *sa,int n,int m) //倍增算法计算出后缀数组sa
{
    static int rank[maxn*2],a[maxn*2],b[maxn*2];
    for(int i=0;i<n;i++) rank[i]=i;
    radix(str,rank,sa,n,m);
    rank[sa[0]]=0;
    for(int i=1;i<n;i++)
        rank[sa[i]]=rank[sa[i-1]]+(str[sa[i]]!=str[sa[i-1]]);
    for(int i=0;1<<i<n;i++)
    {
        for(int j=0;j<n;j++)
        {
            a[j]=rank[j]+1;
            b[j]=j+(1<<i)>=n?0:rank[j+(1<<i)]+1;
            sa[j]=j;
        }
        radix(b,sa,rank,n,n);
        radix(a,rank,sa,n,n);
        rank[sa[0]]=0;
        for(int j=1;j<n;j++)
            rank[sa[j]]=rank[sa[j-1]]+(a[sa[j-1]]!=a[sa[j]]||b[sa[j-1]]!=b[sa[j]]);
    }
}
void calcHeight(int *str,int *sa,int *h,int n) //求出最长公共前缀数组h
{
    static int rank[maxn*2];
    int k=0;
    h[0]=0;
    for(int i=0;i<n;i++) rank[sa[i]]=i;
    for(int i=0;i<n;i++)
    {
        k=k==0?0:k-1;
        if(rank[i])
            while(str[i+k]==str[sa[rank[i]-1]+k]) k++;
        else k=0;
        h[rank[i]]=k;
    }
}
int a[maxn],sa[maxn],height[maxn];
bool binary(int mid,int n)
{
    int Max=-INF,Min=INF;
    for(int i=1;i<n;i++)
    {
        if(height[i]<mid) Max=Min=sa[i];
        else
        {
            if(sa[i]<Min) Min=sa[i];//在同一组里寻找最大与最小的比较差值
            if(sa[i]>Max) Max=sa[i];
            if(Max-Min>=mid) return true;
        }
    }
    return false;
}
int main()
{
    //freopen("test.txt","r",stdin);
    int n,i;
    while(~scanf("%d",&n)&&n)
    {
        for(i=0;i<n;i++)
            scanf("%d",a+i);
        if(n-1<10)
        {
            puts("0");
            continue;
        }
        for(i=0;i<n-1;i++)
            a[i]=a[i+1]-a[i]+90;
        a[--n]=0;
        mem(sa,0),mem(height,0);
        suffix(a,sa,n+1,200);
        calcHeight(a,sa,height,n);
        for(int i=1;i<n;i++)
            cout<<height[i]<<' '<<sa[i]<<' '<<a[i]<<endl;
        int l=1,r=n,mid,ans=0;
        while(l<=r)
        {
            mid=(l+r)>>1;
            if(binary(mid,n)) ans=mid,l=mid+1;
            else r=mid-1;
        }
        if(ans<4) puts("0");
        else printf("%d\n",ans+1);
    }
    return 0;
}
/*
8
1 1 2 1 1 1 1 2
10
1 1 1 1 1 1 1 1 1 1
11
1 1 1 1 1 1 1 1 1 1 1
*/