首页 > 代码库 > 回忆树

回忆树

先上代码

代码

#include <bits/stdc++.h>
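/*
a. The middle segment of the path is not always exactly len*2-2 characters
   long, so the text length fed to KMP must be computed from the actual depths.
b. Forgot a return value (at the end of check2).
*/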
using namespace std;

// Capacity constants shared by all global arrays.
const int N = 101000;    // bound on tree-node indices
const int LOGN = 20;     // binary-lifting levels reserved per node
const int M = 301000;    // size of the query-string / path-text buffers

// Skips input until the next lowercase letter ('a'..'z') and returns it.
// Returns '\0' if end of input is reached before any letter is found, so a
// truncated input cannot make the loop spin forever.
char read(){
    int c = getchar();   // int, so EOF is distinguishable from a real byte
    while (c != EOF && (c < 'a' || 'z' < c)) c = getchar();
    return c == EOF ? '\0' : static_cast<char>(c);
}
// Entry in the per-level lists first[x][level]: `end` is the listed node,
// `next` is the index of the following entry in the same list (0 = end).
struct Edge{
    int next;
    int end;
}edge[LOGN*N];

// Entry in the undirected adjacency list first2[x] of the input tree;
// `ch` is the character labelling the tree edge.
struct Edge2{
    int next;
    int end;
    char ch;
}edge2[N<<1];

// Persistent segment-tree node: `cnt` is the number of inserted positions
// in this node's range, son[0]/son[1] are indices of the child nodes.
struct Node{
    int cnt;
    int son[2];
}nod[N*LOGN];


// ---- Strings ----
char ch[N];            // ch[v]: character on the edge from v to its parent
char s[M];             // current query pattern
char d[M];             // text of the middle path segment scanned by KMP

// ---- Binary lifting ----
int efn;               // number of used entries in edge[]
int fa[N][LOGN];       // fa[v][i]: 2^i-th ancestor of v (0 if none)
int first[N][LOGN];    // first[v][i]: list of nodes whose 2^i-th ancestor is v
int first2[N];         // adjacency-list heads of the input tree (edge2[])
int efn2;              // number of used entries in edge2[]

// ---- Suffix array over the upward node strings ----
// NOTE(review): with `using namespace std` and C++11 or later, unqualified
// `rank` / `next` are ambiguous with std::rank / std::next; refer to these
// globals as ::rank / ::next.
int sa[N];             // sa[i]: node whose string has rank i
int rank[N];           // rank[v]: rank of node v's string

// ---- Heavy-light decomposition ----
int dep[N];            // depth (root has depth 1)
int siz[N];            // subtree size
int hson[N];           // heavy child
int fat[N];            // parent
int top[N];            // head of the heavy chain containing the node

// ---- KMP ----
int next[N];           // failure-function buffer

// ---- Persistent segment tree and bookkeeping ----
int root[N];           // root[v]: tree version holding the ranks of nodes on root..v
int Len;               // length of the current pattern, read by check1/check2
int n;                 // number of tree nodes
int m;                 // number of queries
int o;                 // highest binary-lifting level that is filled in
int lg2[N];            // lg2[i] = floor(log2(i)), lg2[0] = -1
int len;               // last allocated index in nod[]

// ---- Forward declarations ----
void init();
void addedge(int x,int y,int level);
void addedge(int x,int y,char c);
void dfs(int x,int parent);
void dfs3(int x,int parent);
void dfs2(int x,int parent);
void insert(int p,int q,int l,int r,int pos);
void build(int p,int l,int r);
int lca(int x,int y);
int ef1(int l,int r);
int ef2(int l,int r);
int getans(int p,int q,int l,int r,int x,int y);
bool check1(int p);
bool check2(int p);
int main(){
    scanf("%d%d",&n,&m);    
    lg2[0] = -1;
    for (int i = 1;i <= n;i++) lg2[i] = lg2[i/2]+1;
    o = lg2[n]+1;
    
    for (int i = 1;i < n;i++){
        int x,y;char ch;
        scanf("%d%d",&x,&y);
        ch = read();
        addedge(x,y,ch);
    }
    dfs(1,0);
    top[1] = 1;dfs3(1,0);
    
    for (int i = 1;i <= n;i++) if (fa[i][0]) addedge(fa[i][0],i,0);
    
    for (int i = 1;i <= o;i++){
        for (int j = 1;j <= n;j++){
            fa[j][i] = fa[fa[j][i-1]][i-1];
            if (fa[j][i]) addedge(fa[j][i],j,i);
        }
    }
    init();
    len = 1;root[0] = 1;build(1,1,n);
    dfs2(1,0);
    ch[1] =  ;
    
    for (int i = 1;i <= m;i++){
        int x,y,z,u,v;
        int l,r,ans = 0,len,cntt = 0;
        
        scanf("%d%d",&x,&y);z = lca(x,y);
        scanf("%s",s);len = strlen(s);Len = len;
        u = x;v = y;
        for (int j = 16;j >= 0;j--){
            if (dep[fa[u][j]] >= dep[z]+len-1) u = fa[u][j];
            if (dep[fa[v][j]] >= dep[z]+len-1) v = fa[v][j];
        }    
        cntt = dep[u]-dep[z] + dep[v] - dep[z];//a 
        l = ef1(2,n+1);
        r = ef2(1,n);
        if (l <= r)ans += getans(root[u],root[x],1,n,min(n,l),max(2,r));
        reverse(s+0,s+len);
        l = ef1(2,n+1);
        r = ef2(1,n);
        if (l <= r)ans += getans(root[v],root[y],1,n,min(n,l),max(2,r));
        int t1 = 0,w1 = cntt-1;
        while (u != z){
            d[t1++] = ch[u];
            u = fat[u];
        }
        while (v != z){
            d[w1--] = ch[v];
            v = fat[v];
        }    
        reverse(s+0,s+len);
        w1 = cntt;
        next[0] = -1;t1 = -1;
        for (int j = 1;j < len;j++){
            t1++;
            while (t1 && s[t1] != s[j]) t1 = next[t1];
            if (s[t1] == s[j]) next[j] = t1;
            else next[j] = --t1;
        }
        int now = -1;
        for (int j = 0;j < w1;j++){
            while (now != -1 && s[now+1] != d[j]) now = next[now];
            if (s[now+1] == d[j]) now++;
            if (now == len-1) {ans++;now = next[now];}
        }
        printf("%d\n",ans);
    }
    
    return 0;
}
void init(){
    static int x2[N],y2[N],a[N];
    int *x = x2,*y = y2,m = 256,cnt = -1;
    for (int i = 0;i <= m;i++) a[i] = 0;
    for (int i = 1;i <= n;i++) a[x[i] = ch[i]]++;
    for (int i = 1;i <= m;i++) a[i] += a[i-1];
    for (int i = 1;i <= n;i++) sa[a[ch[i]]--] = i;
    x[0] = -1;
    for (int k = 1;k <= n;k <<= 1){
        int p = 0;cnt++;
        for (int i = 0;i <= m;i++) a[i] = 0;
        for (int i = 1;i <= n;i++)
            if (fa[sa[i]][cnt] <= 1) y[++p] = sa[i];
        for (int i = 2;i <= n;i++)
            for (int h = first[sa[i]][cnt];h;h = edge[h].next){
                int u = edge[h].end;
                y[++p] = u;
            }
        for (int i = 1;i <= n;i++) a[x[y[i]]]++;
        for (int i = 1;i <= m;i++) a[i] += a[i-1];
        for (int i = n;i >= 1;i--) sa[a[x[y[i]]]--] = y[i];
        swap(x,y);
        p = 1;
        x[sa[1]] = p;
        for (int i = 2;i <= n;i++){
            int u = fa[sa[i]][cnt],v = fa[sa[i-1]][cnt];
            if (u == 1) u = 0;if (v == 1) v = 0;
            x[sa[i]] = (y[sa[i]] == y[sa[i-1]]) ? (u == 0 && v == 0 ? p : y[u] == y[v] ? p : ++p) : ++p;        
        }
        m = p;
        if (m >= n) break;
    }
    for (int i = 1;i <= n;i++) rank[sa[i]] = i;
}
void addedge(int x,int y,int z){
    edge[++efn].end = y;
    edge[  efn].next = first[x][z];
    first[x][z] = efn;
}
void dfs(int x,int y){
    fa[x][0] = y;
    siz[x] = 1;fat[x] = y;dep[x] = dep[y]+1;
    for (int h = first2[x];h;h = edge2[h].next){
        int u = edge2[h].end;
        if (u != y) {
            ch[u] = edge2[h].ch;
            dfs(u,x);
            siz[x] += siz[u];
            hson[x] = siz[u] > siz[hson[x]] ? u : hson[x];
        }
    }
}
// Fills top[] (head of each heavy chain). The heavy child inherits top[x];
// every light child starts its own chain. top[x] must already be set.
void dfs3(int x,int y){
    const int heavy = hson[x];
    if (heavy != 0){
        top[heavy] = top[x];
        dfs3(heavy,x);
    }
    for (int e = first2[x];e;e = edge2[e].next){
        const int child = edge2[e].end;
        if (child == y || child == heavy) continue;
        top[child] = child;
        dfs3(child,x);
    }
}
void dfs2(int x,int y){
    root[x] = ++len;
    insert(root[y],root[x],1,n,rank[x]);
    for (int h = first[x][0];h;h = edge[h].next){
        int u = edge[h].end;
        dfs2(u,x);
    }
}
// Builds the empty base version of the segment tree over [l, r] rooted at
// node p, allocating children from the global counter len.
void build(int p,int l,int r){
    if (l == r) return;
    const int mid = (l + r) / 2;
    int &lc = nod[p].son[0];
    int &rc = nod[p].son[1];
    lc = ++len;
    rc = ++len;
    build(lc,l,mid);
    build(rc,mid+1,r);
}
// Makes q a copy of version p with position x incremented by one. Only the
// nodes on the root-to-leaf path are new; the other child is shared with p.
void insert(int p,int q,int l,int r,int x){
    for (;;){
        nod[q].cnt = nod[p].cnt+1;
        if (l == r) return;
        const int mid = (l + r) / 2;
        if (x <= mid){
            nod[q].son[0] = ++len;
            nod[q].son[1] = nod[p].son[1];
            p = nod[p].son[0];
            q = nod[q].son[0];
            r = mid;
        } else {
            nod[q].son[1] = ++len;
            nod[q].son[0] = nod[p].son[0];
            p = nod[p].son[1];
            q = nod[q].son[1];
            l = mid+1;
        }
    }
}
// Lowest common ancestor via heavy-light chains. It repeatedly lifts the
// endpoint whose chain head is deeper, then returns the shallower endpoint.
int lca(int x,int y){
    for (;top[x] != top[y];){
        if (dep[top[x]] >= dep[top[y]]) x = fat[top[x]];
        else y = fat[top[y]];
    }
    return dep[y] <= dep[x] ? y : x;
}
// Adds the undirected tree edge x--y labelled ch as two directed entries:
// first x->y, then y->x.
void addedge(int x,int y,char ch){
    const int from[2] = {x,y};
    const int to[2] = {y,x};
    for (int k = 0;k < 2;k++){
        ++efn2;
        edge2[efn2].end = to[k];
        edge2[efn2].ch = ch;
        edge2[efn2].next = first2[from[k]];
        first2[from[k]] = efn2;
    }
}
// Binary search over suffix-array positions [l, r]: smallest position whose
// node satisfies check1 (returns r if none in [l, r-1] does).
int ef1(int l,int r){
    int lo = l;
    int hi = r;
    while (lo < hi){
        const int mid = (lo + hi) / 2;
        if (check1(sa[mid])) hi = mid;
        else lo = mid + 1;
    }
    return lo;
}
// Binary search over suffix-array positions [l, r]: largest position whose
// node satisfies check2 (returns l if none in [l+1, r] does).
int ef2(int l,int r){
    int lo = l;
    int hi = r;
    while (lo < hi){
        const int mid = (lo + hi + 1) / 2;
        if (check2(sa[mid])) lo = mid;
        else hi = mid - 1;
    }
    return lo;
}
// Compares the upward string of node p with the first Len chars of s.
// Returns true if that string is >= s, where having s as a prefix counts as
// >=. Reaching the root (node 1) first counts as smaller.
bool check1(int p){
    for (int i = 0;i < Len;i++, p = fat[p]){
        if (p == 1 || ch[p] < s[i]) return false;
        if (ch[p] > s[i]) return true;
    }
    return true;
}
// Compares the upward string of node p with the first Len chars of s.
// Returns true if that string is <= s, where having s as a prefix also
// counts as true. Reaching the root (node 1) first counts as smaller.
bool check2(int p){
    for (int i = 0;i < Len;i++, p = fat[p]){
        if (p == 1 || ch[p] < s[i]) return true;
        if (ch[p] > s[i]) return false;
    }
    return true;   // b: the whole pattern matched
}
// Number of positions in [x, y] present in version q but not in version p.
// The current node covers [l, r]; returns 0 for an empty query range.
int getans(int p,int q,int l,int r,int x,int y){
    if (x > y) return 0;
    if (x == l && y == r) return nod[q].cnt - nod[p].cnt;
    const int mid = (l + r) / 2;
    int total = 0;
    if (x <= mid){
        const int leftEnd = y < mid ? y : mid;
        total += getans(nod[p].son[0],nod[q].son[0],l,mid,x,leftEnd);
    }
    if (y > mid){
        const int rightBegin = x > mid + 1 ? x : mid + 1;
        total += getans(nod[p].son[1],nod[q].son[1],mid+1,r,rightBegin,y);
    }
    return total;
}

 

回忆树