首页 > 代码库 > hdu 4929 Another Letter Tree(LCA+DP)

hdu 4929 Another Letter Tree(LCA+DP)

hdu 4929 Another Letter Tree(LCA+DP)

题意:有一棵树n个节点(n<=50000),树上每个节点上有一个字母。m个询问(m<=50000),每次询问一个(a,b),问a节点到b节点的点不重复路径组成的字符串中子序列为s0的情况有多少种,s0长度小于等于30(注意s0是已经给定的,而不是每次询问都会给出一个新的)。

解法:一个很直观的想法,求出lca(设其为w)后,枚举x,求出a到w的路径上,能匹配s0的x长度前缀的情况有多少种,令其为c[x]。再求出b到w的路径上能匹配s0的L-x(L表示s0的长度)长度后缀的情况有多少种,令其为d[l-x],那么将所有的c[x]*d[l-x](x属于[0,l])加起来,即为答案(当然这里要考虑w这个点,不能同时出现在两部分当中,处理方法是w这个位置两部分都不要,然后在考虑w这个位置一定被选进去,两种情况加起来即可)。然后问题的难点在于,考虑某个节点u时,如何处理出c[i]与d[i]。这里,我们需要预处理一个dp数组,dp[i][j][u]表示,从u节点到根的路径匹配了s0[i,j]这段子串的子序列有多少种。那么c[i]就等于u到根的路径匹配了s0的i长度前缀情况数,减去有长度a的前缀在a到w的路径上(因为我们先考虑的是w两边都不要,这里其实我们要的是a到w的前一个节点的路径,c[i]考虑的也是这条路经)的情况数,即c[a](这里,因为我们是从小到大递推c[i],而a又小于i,故在求c[i]之前,我们必然已经推出过了c[a],直接拿来用),乘以s0[a+1,i]的子串匹配在w到根的路径上的序列的情况数(这个就是前面预处理的dp数组,拿来用即可)。求d[i]亦是同样地方法,这里时间复杂度主要是在预处理上。整体时间复杂度为n*l*l,问题得解。

代码:

#pragma comment(linker, "/STACK:1024000000,1024000000")
#include<stdio.h>
#include<string.h>
#include<algorithm>
#include<vector>
using namespace std ;
void get_num ( int& n ) {
    n = 0 ;
    char c ;
    while ( c = getchar () ) {
        if ( c >= '0' && c <= '9' ) break ;
    }
    n = c - '0' ;
    while ( c = getchar () ) {
        if ( c < '0' ¦¦ c > '9' ) break ;
        n = n * 10 + c -'0' ;
    }
}
const int maxn = 50005 ;
const int mod = 10007 ;
short dp[2][33][33][maxn] ;
int c[33] , d[33] ;
char s[maxn] , s1[33] ;

vector<int> vec[maxn] ;
int p[20][maxn] , fa[maxn] , deep[maxn] ;
struct LCA {
    void dfs ( int u ) {
        if ( u == 1 ) fa[u] = 0 ;
        p[0][u] = fa[u] ;
        deep[u] = deep[fa[u]] + 1 ;
        for ( int i = 1 ; i < 20 ; i ++ ) p[i][u] = p[i-1][p[i-1][u]] ;
        int sz = vec[u].size () ;
        for ( int i = 0 ; i < sz ; i ++ ) {
            int v = vec[u][i] ;
            if ( v == fa[u] ) continue ;
            fa[v] = u ;
            dfs ( v ) ;
        }
    }
    int father_k ( int u , int k ) {
        for ( int i = 0 ; i < 20 ; i ++ )
            if ( k & ( 1 << i ) )
                u = p[i][u] ;
        return u ;
    }
    int query ( int a , int b ) {
        if ( deep[a] > deep[b] ) swap ( a , b ) ;
        b = father_k ( b , deep[b] - deep[a] ) ;
        if ( a == b ) return a ;
        for ( int i = 19 ; i >= 0 ; i -- ) {
            if ( fa[a] == fa[b] ) break ;
            if ( p[i][a] != p[i][b] ) {
                a = p[i][a] ;
                b = p[i][b] ;
            }
        }
        return fa[a] ;
    }
} lca ;

int l ;
void dfs ( int u , int x , int c ) {
    for ( int i = x ; i <= l ; i ++ ) {
        dp[c][x][i][u] += dp[c][x][i][fa[u]] ;
        if ( dp[c][x][i][u] >= mod ) dp[c][x][i][u] -= mod ;
        if ( s[u] == s1[i] )
            dp[c][x][i][u] += dp[c][x][i-1][fa[u]] ;
        if ( dp[c][x][i][u] >= mod ) dp[c][x][i][u] -= mod ;
    }
    int sz = vec[u].size () ;
    for ( int i = 0 ; i < sz ; i ++ ) {
        int v = vec[u][i] ;
        if ( v == fa[u] ) continue ;
        dfs ( v , x , c ) ;
    }
}

void DP ( int n , int c ) {
    for ( int i = 0 ; i <= l + 1 ; i ++ ) {
        for ( int j = 0 ; j <= n ; j ++ ) {
            for ( int k = 0 ; k <= i ; k ++ )
                dp[c][k][i][j] = 0 ;
            if (i) dp[c][i][i-1][j] = 1 ;
        }
    }
    for ( int i = 1 ; i <= l ; i ++ )
        dfs ( 1 , i , c ) ;
}

int main () {
    int T , n , q ;
    scanf ( "%d" , &T ) ;
    while ( T -- ) {
        scanf ( "%d%d" , &n , &q ) ;
        for ( int i = 1 ; i <= n ; i ++ ) vec[i].clear () ;
        for ( int i = 1 ; i < n ; i ++ ) {
            int a , b ;
            get_num (a) ;
            get_num (b) ;
            vec[a].push_back (b) ;
            vec[b].push_back (a) ;
        }
        scanf ( "%s" , s + 1 ) ;
        scanf ( "%s" , s1 + 1 ) ;
        l = strlen ( s1 + 1 ) ;
        lca.dfs ( 1 ) ;
        reverse ( s1 + 1 , s1 + l + 1 ) ;
        DP ( n , 0 ) ;
        reverse ( s1 + 1 , s1 + l + 1 ) ;
        DP ( n , 1 ) ;
        while ( q -- ) {
            int a , b , x , y ;
            get_num (a) ;
            get_num (b) ;
            if ( a == b ) {
                if ( l == 1 && s[a] == s1[1] ) puts ( "1" ) ;
                else puts ( "0" ) ;
                continue ;
            }
            int w = lca.query ( a , b ) ;
            int ans = 0 ;
            memset ( c , 0 , sizeof ( c ) ) ;
            memset ( d , 0 , sizeof ( d ) ) ;
            for ( int i = 0 ; i <= l ; i ++ ) {
                c[i] = dp[0][l-i+1][l][a] ;
                d[i] = dp[1][l-i+1][l][b] ;
         //       printf ( "d[%d] = %d\n" , i , d[i] ) ;
                for ( int j = 0 ; j < i ; j ++ ) {
                    c[i] -= (c[j] * dp[0][l-i+1][l-j][w] % mod) ;
                    d[i] -= (d[j] * dp[1][l-i+1][l-j][w] % mod) ;
                    c[i] += mod ;
                    if ( c[i] >= mod ) c[i] -= mod ;
                    d[i] += mod ;
                    if ( d[i] >= mod ) d[i] -= mod ;
                }
            //    printf ( "c[%d] = %d , d[%d] = %d\n" , i , c[i] , i , d[i] ) ;
            }
            for ( int i = 0 ; i <= l ; i ++ ) {
                ans += c[i] * d[l-i] % mod ;
                if ( ans >= mod ) ans -= mod ;
            }
            for ( int i = 0 ; i < l ; i ++ ) {
                if ( s[w] == s1[i+1] ) {
                    ans += c[i] * d[l-i-1] % mod ;
                    if ( ans >= mod ) ans -= mod ;
                }
            }
            printf ( "%d\n" , ans ) ;
        }
    }
}
/*
1000
12 1000
1 2
1 3
2 4
2 5
2 6
5 9
5 10
3 7
3 8
8 11
8 12
abbaabbababb ba
8 6
2 10
10 2
1 2 9 0
1 2 10 0
1 2 9 1
1 2 10 1
*/