首页 > 代码库 > hdu 4929 Another Letter Tree(LCA+DP)
hdu 4929 Another Letter Tree(LCA+DP)
hdu 4929 Another Letter Tree(LCA+DP)
题意:有一棵树n个节点(n<=50000),树上每个节点上有一个字母。m个询问(m<=50000),每次询问一个(a,b),问a节点到b节点的点不重复路径组成的字符串中子序列为s0的情况有多少种,s0长度小于等于30(注意s0是已经给定的,而不是每次询问都会给出一个新的)。
解法:一个很直观的想法,求出lca(设其为w)后,枚举x,求出a到w的路径上,能匹配s0的x长度前缀的情况有多少种,令其为c[x]。再求出b到w的路径上能匹配s0的L-x(L表示s0的长度)长度后缀的情况有多少种,令其为d[l-x],那么将所有的c[x]*d[l-x](x属于[0,l])加起来,即为答案(当然这里要考虑w这个点,不能同时出现在两部分当中,处理方法是w这个位置两部分都不要,然后在考虑w这个位置一定被选进去,两种情况加起来即可)。然后问题的难点在于,考虑某个节点u时,如何处理出c[i]与d[i]。这里,我们需要预处理一个dp数组,dp[i][j][u]表示,从u节点到根的路径匹配了s0[i,j]这段子串的子序列有多少种。那么c[i]就等于u到根的路径匹配了s0的i长度前缀情况数,减去有长度a的前缀在a到w的路径上(因为我们先考虑的是w两边都不要,这里其实我们要的是a到w的前一个节点的路径,c[i]考虑的也是这条路经)的情况数,即c[a](这里,因为我们是从小到大递推c[i],而a又小于i,故在求c[i]之前,我们必然已经推出过了c[a],直接拿来用),乘以s0[a+1,i]的子串匹配在w到根的路径上的序列的情况数(这个就是前面预处理的dp数组,拿来用即可)。求d[i]亦是同样地方法,这里时间复杂度主要是在预处理上。整体时间复杂度为n*l*l,问题得解。
代码:
#pragma comment(linker, "/STACK:1024000000,1024000000") #include<stdio.h> #include<string.h> #include<algorithm> #include<vector> using namespace std ; void get_num ( int& n ) { n = 0 ; char c ; while ( c = getchar () ) { if ( c >= '0' && c <= '9' ) break ; } n = c - '0' ; while ( c = getchar () ) { if ( c < '0' ¦¦ c > '9' ) break ; n = n * 10 + c -'0' ; } } const int maxn = 50005 ; const int mod = 10007 ; short dp[2][33][33][maxn] ; int c[33] , d[33] ; char s[maxn] , s1[33] ; vector<int> vec[maxn] ; int p[20][maxn] , fa[maxn] , deep[maxn] ; struct LCA { void dfs ( int u ) { if ( u == 1 ) fa[u] = 0 ; p[0][u] = fa[u] ; deep[u] = deep[fa[u]] + 1 ; for ( int i = 1 ; i < 20 ; i ++ ) p[i][u] = p[i-1][p[i-1][u]] ; int sz = vec[u].size () ; for ( int i = 0 ; i < sz ; i ++ ) { int v = vec[u][i] ; if ( v == fa[u] ) continue ; fa[v] = u ; dfs ( v ) ; } } int father_k ( int u , int k ) { for ( int i = 0 ; i < 20 ; i ++ ) if ( k & ( 1 << i ) ) u = p[i][u] ; return u ; } int query ( int a , int b ) { if ( deep[a] > deep[b] ) swap ( a , b ) ; b = father_k ( b , deep[b] - deep[a] ) ; if ( a == b ) return a ; for ( int i = 19 ; i >= 0 ; i -- ) { if ( fa[a] == fa[b] ) break ; if ( p[i][a] != p[i][b] ) { a = p[i][a] ; b = p[i][b] ; } } return fa[a] ; } } lca ; int l ; void dfs ( int u , int x , int c ) { for ( int i = x ; i <= l ; i ++ ) { dp[c][x][i][u] += dp[c][x][i][fa[u]] ; if ( dp[c][x][i][u] >= mod ) dp[c][x][i][u] -= mod ; if ( s[u] == s1[i] ) dp[c][x][i][u] += dp[c][x][i-1][fa[u]] ; if ( dp[c][x][i][u] >= mod ) dp[c][x][i][u] -= mod ; } int sz = vec[u].size () ; for ( int i = 0 ; i < sz ; i ++ ) { int v = vec[u][i] ; if ( v == fa[u] ) continue ; dfs ( v , x , c ) ; } } void DP ( int n , int c ) { for ( int i = 0 ; i <= l + 1 ; i ++ ) { for ( int j = 0 ; j <= n ; j ++ ) { for ( int k = 0 ; k <= i ; k ++ ) dp[c][k][i][j] = 0 ; if (i) dp[c][i][i-1][j] = 1 ; } } for ( int i = 1 ; i <= l ; i ++ ) dfs ( 1 , i , c ) ; } int main () { int T , n , q ; scanf ( "%d" , &T ) ; while ( T -- ) { scanf ( "%d%d" , &n , &q ) ; for ( int i = 1 ; i <= n ; i ++ ) vec[i].clear () ; for ( int i = 1 ; i < n ; i ++ ) { int a , b ; get_num (a) ; get_num (b) ; vec[a].push_back (b) ; vec[b].push_back (a) ; } scanf ( "%s" , s + 1 ) ; scanf ( "%s" , s1 + 1 ) ; l = strlen ( s1 + 1 ) ; lca.dfs ( 1 ) ; reverse ( s1 + 1 , s1 + l + 1 ) ; DP ( n , 0 ) ; reverse ( s1 + 1 , s1 + l + 1 ) ; DP ( n , 1 ) ; while ( q -- ) { int a , b , x , y ; get_num (a) ; get_num (b) ; if ( a == b ) { if ( l == 1 && s[a] == s1[1] ) puts ( "1" ) ; else puts ( "0" ) ; continue ; } int w = lca.query ( a , b ) ; int ans = 0 ; memset ( c , 0 , sizeof ( c ) ) ; memset ( d , 0 , sizeof ( d ) ) ; for ( int i = 0 ; i <= l ; i ++ ) { c[i] = dp[0][l-i+1][l][a] ; d[i] = dp[1][l-i+1][l][b] ; // printf ( "d[%d] = %d\n" , i , d[i] ) ; for ( int j = 0 ; j < i ; j ++ ) { c[i] -= (c[j] * dp[0][l-i+1][l-j][w] % mod) ; d[i] -= (d[j] * dp[1][l-i+1][l-j][w] % mod) ; c[i] += mod ; if ( c[i] >= mod ) c[i] -= mod ; d[i] += mod ; if ( d[i] >= mod ) d[i] -= mod ; } // printf ( "c[%d] = %d , d[%d] = %d\n" , i , c[i] , i , d[i] ) ; } for ( int i = 0 ; i <= l ; i ++ ) { ans += c[i] * d[l-i] % mod ; if ( ans >= mod ) ans -= mod ; } for ( int i = 0 ; i < l ; i ++ ) { if ( s[w] == s1[i+1] ) { ans += c[i] * d[l-i-1] % mod ; if ( ans >= mod ) ans -= mod ; } } printf ( "%d\n" , ans ) ; } } } /* 1000 12 1000 1 2 1 3 2 4 2 5 2 6 5 9 5 10 3 7 3 8 8 11 8 12 abbaabbababb ba 8 6 2 10 10 2 1 2 9 0 1 2 10 0 1 2 9 1 1 2 10 1 */