首页 > 代码库 > poj 3415 后缀数组分组+排序+并查集

poj 3415 后缀数组分组+排序+并查集

Source Code

Problem: 3415 User: wangyucheng
Memory: 16492K Time: 704MS
Language: C++ Result: Accepted
    • Source Code
#include<iostream>#include<cstdio>#include<algorithm>#include<cstring>using namespace std;#define N 510000typedef long long ll;int wa[N],wb[N],sa[N],wv[N],ss[N],a[N];int n;int a1,a2;int cmp(int *r,int x,int y,int k){   return r[x]==r[y]&&r[x+k]==r[y+k];	}structP{   int x,y,z;   P(int a=0,int b=0){	   x=a,y=b;		}		bool operator<(P a)const{	   return x>a.x;	}}b[N];int b1;void da(int *r,int m){    int p,i,j,*x=wa,*y=wb;	for(i=0;i<n;i++)r[i]++;	r[n++]=0;	for(i=0;i<m;i++)ss[i]=0;	for(i=0;i<n;i++)ss[x[i]=r[i]]++;	for(i=1;i<m;i++)ss[i]+=ss[i-1];	for(i=n-1;i>=0;i--)sa[--ss[x[i]]]=i;	for(p=0,j=1;p<n;j<<=1,m=p){	   for(p=0,i=n-j;i<n;i++)y[p++]=i;	   for(i=0;i<n;i++)if(sa[i]>=j)y[p++]=sa[i]-j;	   for(i=0;i<m;i++)ss[i]=0;	   for(i=0;i<n;i++)wv[i]=x[y[i]];	   for(i=0;i<n;i++)ss[wv[i]]++;	   for(i=1;i<m;i++)ss[i]+=ss[i-1];	   for(i=n-1;i>=0;i--)sa[--ss[wv[i]]]=y[i];	   	for(swap(x,y),x[sa[0]]=0,p=1,i=1;i<n;i++)		x[sa[i]]=cmp(y,sa[i],sa[i-1],j)?p-1:p++;	}	}int rank[N],he[N];void ma(int *r){	int i,k=0;	for(i=0;i<n;i++)rank[sa[i]]=i;	for(i=0;i<n-1;i++){	 	for(k?k--:0;r[sa[rank[i]-1]+k]==r[i+k];k++);		he[rank[i]]=k;	}}char in[N];int K;int f[N];int get(int x){	return f[x]==x?x:f[x]=get(f[x]);}ll ans;ll s[N][2];void he1(int x,int y,ll &z){	int c=get(x);	int d=get(y);	z+=s[c][0]*s[d][1]+s[d][0]*s[c][1];	f[c]=d;	s[d][0]+=s[c][0];	s[d][1]+=s[c][1];}void solv(int l,int r){	 int i,j,y;	if(a[sa[l]]-1==‘#‘)return;	b1=0;	for(i=l;i<=r;i++){		s[i][0]=s[i][1]=0;	    if(sa[i]<a1)y=1;		else y=0;		s[i][y]++;		if(i==l)b[++b1]=P(n+1,i);		else b[++b1]=P(he[i]-K+1,i);	}	sort(b+1,b+b1+1);	for(i=l;i<=r;i++)f[i]=i;	int la;	ll tot=0;	la=0;	b[b1+1].x=0;	for(i=2;i<=b1+1;i++){	    if(i==b1+1||b[i].x!=b[i-1].x){		   for(j=la+1;j<i;j++){			   if(b[j].y>l)he1(b[j].y-1,b[j].y,tot);			}		   la=i-1;		   ans+=tot*(ll)(b[i-1].x-b[i].x);		}	}}int main(){	while(scanf("%d",&K),K){		ans=0;	   scanf("%s",in);	   a1=strlen(in);	   int i,j;	   for(i=0;i<a1;i++)a[i]=in[i];	   a[a1]=‘#‘;		   scanf("%s",in);	   a2=strlen(in);	   n=a1+a2+1;	   for(i=a1+1;i<n;i++)a[i]=in[i-a1-1];	   da(a,300);	   ma(a);		int la=0;		for(i=1;i<=n;i++){		    if(he[i]<K||i==n){			    solv(la,i-1);				la=i;				}			}		printf("%lld\n",ans);	}		}