首页 > 代码库 > poj 3415 后缀数组分组+排序+并查集
poj 3415 后缀数组分组+排序+并查集
Source Code
Problem: 3415 | User: wangyucheng | |
Memory: 16492K | Time: 704MS | |
Language: C++ | Result: Accepted |
- Source Code
#include<iostream>#include<cstdio>#include<algorithm>#include<cstring>using namespace std;#define N 510000typedef long long ll;int wa[N],wb[N],sa[N],wv[N],ss[N],a[N];int n;int a1,a2;int cmp(int *r,int x,int y,int k){ return r[x]==r[y]&&r[x+k]==r[y+k]; }structP{ int x,y,z; P(int a=0,int b=0){ x=a,y=b; } bool operator<(P a)const{ return x>a.x; }}b[N];int b1;void da(int *r,int m){ int p,i,j,*x=wa,*y=wb; for(i=0;i<n;i++)r[i]++; r[n++]=0; for(i=0;i<m;i++)ss[i]=0; for(i=0;i<n;i++)ss[x[i]=r[i]]++; for(i=1;i<m;i++)ss[i]+=ss[i-1]; for(i=n-1;i>=0;i--)sa[--ss[x[i]]]=i; for(p=0,j=1;p<n;j<<=1,m=p){ for(p=0,i=n-j;i<n;i++)y[p++]=i; for(i=0;i<n;i++)if(sa[i]>=j)y[p++]=sa[i]-j; for(i=0;i<m;i++)ss[i]=0; for(i=0;i<n;i++)wv[i]=x[y[i]]; for(i=0;i<n;i++)ss[wv[i]]++; for(i=1;i<m;i++)ss[i]+=ss[i-1]; for(i=n-1;i>=0;i--)sa[--ss[wv[i]]]=y[i]; for(swap(x,y),x[sa[0]]=0,p=1,i=1;i<n;i++) x[sa[i]]=cmp(y,sa[i],sa[i-1],j)?p-1:p++; } }int rank[N],he[N];void ma(int *r){ int i,k=0; for(i=0;i<n;i++)rank[sa[i]]=i; for(i=0;i<n-1;i++){ for(k?k--:0;r[sa[rank[i]-1]+k]==r[i+k];k++); he[rank[i]]=k; }}char in[N];int K;int f[N];int get(int x){ return f[x]==x?x:f[x]=get(f[x]);}ll ans;ll s[N][2];void he1(int x,int y,ll &z){ int c=get(x); int d=get(y); z+=s[c][0]*s[d][1]+s[d][0]*s[c][1]; f[c]=d; s[d][0]+=s[c][0]; s[d][1]+=s[c][1];}void solv(int l,int r){ int i,j,y; if(a[sa[l]]-1==‘#‘)return; b1=0; for(i=l;i<=r;i++){ s[i][0]=s[i][1]=0; if(sa[i]<a1)y=1; else y=0; s[i][y]++; if(i==l)b[++b1]=P(n+1,i); else b[++b1]=P(he[i]-K+1,i); } sort(b+1,b+b1+1); for(i=l;i<=r;i++)f[i]=i; int la; ll tot=0; la=0; b[b1+1].x=0; for(i=2;i<=b1+1;i++){ if(i==b1+1||b[i].x!=b[i-1].x){ for(j=la+1;j<i;j++){ if(b[j].y>l)he1(b[j].y-1,b[j].y,tot); } la=i-1; ans+=tot*(ll)(b[i-1].x-b[i].x); } }}int main(){ while(scanf("%d",&K),K){ ans=0; scanf("%s",in); a1=strlen(in); int i,j; for(i=0;i<a1;i++)a[i]=in[i]; a[a1]=‘#‘; scanf("%s",in); a2=strlen(in); n=a1+a2+1; for(i=a1+1;i<n;i++)a[i]=in[i-a1-1]; da(a,300); ma(a); int la=0; for(i=1;i<=n;i++){ if(he[i]<K||i==n){ solv(la,i-1); la=i; } } printf("%lld\n",ans); } }
声明:以上内容来自用户投稿及互联网公开渠道收集整理发布,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任,若内容有误或涉及侵权可进行投诉: 投诉/举报 工作人员会在5个工作日内联系你,一经查实,本站将立刻删除涉嫌侵权内容。