首页 > 代码库 > hdu 4029 Distinct Sub-matrix (后缀数组)

hdu 4029 Distinct Sub-matrix (后缀数组)

题目大意:

n*m的矩阵中,有多少个子矩阵不是同的。


思路分析:

假设这题题目只是一维的求一个串中有多少个子串是不同的。

那么也就是直接扫描height,然后减去前缀。


现在变成二维,如何降低维度。

知道hash 的作用就是将一个串映射到一个数字。

那我们就将这个矩阵hash,考虑到不同的长度和宽度都会导致不同,

所以就要枚举子矩阵的宽度。

hash [i][j] 就表示在当前宽度W 下,从 第 i 行 第 j 个开始往后W长度的串的hash值。

然后将列上相同起点的hash值 子串。

然后将所有的子串组合成 要跑后缀数组的串。


后缀数组之后就和一维的处理方式一样了。


#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <map>
#define maxn 100005

using namespace std;
typedef unsigned long long ull;
const int base = 103;
int str[maxn];
int sa[maxn],t1[maxn],t2[maxn],c[maxn],n;

void suffix(int m)
{
    int *x=t1,*y=t2;
    for(int i=0; i<m; i++)c[i]=0;
    for(int i=0; i<n; i++)c[x[i]=str[i]]++;
    for(int i=1; i<m; i++)c[i]+=c[i-1];
    for(int i=n-1; i>=0; i--)sa[--c[x[i]]]=i;
    for(int k=1; k<=n; k<<=1)
    {
        int p=0;
        for(int i=n-k; i<n; i++)y[p++]=i;
        for(int i=0; i<n; i++)if(sa[i]>=k)y[p++]=sa[i]-k;
        for(int i=0; i<m; i++)c[i]=0;
        for(int i=0; i<n; i++)c[x[y[i]]]++;
        for(int i=0; i<m; i++)c[i]+=c[i-1];
        for(int i=n-1; i>=0; i--)sa[--c[x[y[i]]]]=y[i];
        swap(x,y);
        p=1;
        x[sa[0]]=0;
        for(int i=1; i<n; i++)
            x[sa[i]]=y[sa[i-1]]==y[sa[i]]&&y[sa[i-1]+k]==y[sa[i]+k]?p-1:p++;
        if(p>=n)break;
        m=p;
    }
}
int rank[maxn],height[maxn];
void getheight()
{
    int k=0;
    for(int i=0; i<n; i++)rank[sa[i]]=i;
    for(int i=0; i<n; i++)
    {
        if(k)k--;
        if(!rank[i])continue;
        int j=sa[rank[i]-1];
        while(str[i+k]==str[j+k])k++;
        height[rank[i]]=k;
    }
}

char ch[200][200];
ull hash[200][200];
map <ull,int>cq;

int main()
{
    int T;
    scanf("%d",&T);
    for(int cas=1;cas<=T;cas++)
    {

        int N,M;
        scanf("%d%d",&N,&M);
        for(int i=0;i<N;i++)
            scanf("%s",ch[i]);

        ull ans=0;
        memset(hash,0,sizeof hash);
        for(int w=1;w<=M;w++)
        {
            int tot=0;
            cq.clear();
            for(int i=0;i<N;i++)
            for(int j=0;j+w-1<M;j++){
                hash[i][j]=hash[i][j]*base+ch[i][j+w-1]-'A';
                if(!cq[hash[i][j]])cq[hash[i][j]]=++tot;
            }

            int cnt=0;
            for(int j=0;j+w-1<M;j++){
                for(int i=0;i<N;i++)
                {
                    str[cnt++]=cq[hash[i][j]];
                }
                str[cnt++]=++tot;
            }
            str[cnt-1]=0;
            n=cnt;
            suffix(tot);
            getheight();
            ull tmp = (N*(N+1)/2)*(M-w+1);

            for(int i=1;i<cnt;i++){
                tmp-=height[i];
            }
            ans+=tmp;
        }
        printf("Case #%d: %I64d\n",cas,ans);
    }
    return 0;
}