首页 > 代码库 > hdu 4878 ZCC loves words(AC自动机+dp+矩阵快速幂+中国剩余定理)

hdu 4878 ZCC loves words(AC自动机+dp+矩阵快速幂+中国剩余定理)

hdu 4878 ZCC loves words(AC自动机+dp+矩阵快速幂+中国剩余定理)

题意:给出若干个模式串,总长度不超过40,对于某一个字符串,它有一个价值,对于这个价值的计算方法是这样的,设初始价值为V=1,假如这个串能匹配第k个模式串,则V=V*prime[k]*(i+len[k]),其中prime[k]表示第k个素数,i表示匹配的结束位置,len[k]表示第k个模式串的长度(注意,一个字符串可以多次匹配同意个模式串)。问字符集为‘A‘-‘Z‘的字符,组成的所有的长为L的字符串,能得到的总价值和是多少?

解法:跟以前做过的很多AC自动机的题有点类似,很容易想到一个node*L的dp,dp[i][v]表示长为i,匹配到AC自动机的V节点能得到的价值和(详见代码中的DEBUG函数)。但是L太大,没法搞。节点总数只有40,那么就可以用矩阵来加速dp了,但是很可惜,建立矩阵的时候,发现建的矩阵居然是跟i有关,这样是不能直接用矩阵快速幂做的。但是,题目给出的提示是,mod可以拆成三个较小的质数。那么我们可以分别用三个较小的质数作为mod进行运算,因为第i个矩阵,它是跟第i+mod个矩阵一样的,所以我们可以把L个矩阵分成L/mod段,每一段的矩阵乘起来都是一样的,设其为A(可以暴力乘起来,因为mod很小),那么我们要的所有的L个矩阵的乘起来得到的矩阵,就是A^(L/mod),再乘上剩下来多余的L%mod个了,这样就可以计算出在每个较小的模系下的答案。最后用中国剩余定理计算总的答案。

代码:

#include<stdio.h>
#include<string.h>
#include<algorithm>
#include<vector>
#include<queue>
#define ll __int64
using namespace std ;

const int N = 44 ;
const int mod = 5047621 ;
int pri[12345] , p_num , vis[12345] ;

void get_prime () {
    p_num = 0 ;
    for ( int i = 2 ; i < 12345 ; i ++ ) {
        if ( !vis[i] ) pri[++p_num] = i ;
        for ( int j = 1 ; j <= p_num ; j ++ ) {
            if ( i * pri[j] >= 12345 ) break ;
            vis[i*pri[j]] = 1 ;
            if ( i % pri[j] == 0 ) break ;
        }
    }
}

struct Point {
    int len , p ;
    Point () {}
    Point ( int a , int b ):len(a),p(b) {}
} ;
struct RECT {
    int elem[N][N] ;
    void print ( int n ) {
        for ( int i = 0 ; i < n ; i ++ , puts ( "" ) )
            for ( int j = 0 ; j < n ; j ++ )
                printf ( "%d " , elem[i][j] ) ;
    }
} p[222] , E ;
struct AC_auto {
    int dp[111][44] ;
    int c[N][26] , fail[N] , tot ;
    vector<Point> vec[N] ;
    queue<int> Q ;
    void init () {
        tot = 0 ;
        new_node () ;
    }
    int new_node () {
        vec[tot].clear () ;
        fail[tot] = 0 ;
        memset ( c[tot] , 0 , sizeof ( c[tot] ) ) ;
        return tot ++ ;
    }
    void insert ( char *s , int i ) {
        int now = 0 , len = strlen ( s ) ;
        for ( ; *s ; s ++ ) {
            int k = *s - 'A' ;
            if ( !c[now][k] ) c[now][k] = new_node () ;
            now = c[now][k] ;
        }
        vec[now].push_back ( Point ( len , pri[i] ) ) ;
    }
    void get_fail () {
        int u = 0 , v ;
        for ( int i = 0 ; i < 26 ; i ++ ) {
            if ( c[u][i] )
                Q.push ( c[u][i] ) ;
        }
        while ( !Q.empty () ) {
            u = Q.front () ; Q.pop () ;
            for ( int i = 0 ; i < 26 ; i ++ ) {
                if ( c[u][i] ) {
                    v = c[u][i] ;
                    fail[v] = c[fail[u]][i] ;
                    Q.push ( v ) ;
                } else c[u][i] = c[fail[u]][i] ;
            }
        }
    }
    void BUILD_RECT ( int l , int mod ) {
        memset ( p[l].elem , 0 , sizeof ( p[l].elem ) ) ;
        for ( int i = 0 ; i < tot ; i ++ ) {
            for ( int j = 0 ; j < 26 ; j ++ ) {
                int u = c[i][j] ;
                int v = u , ret = 1 ;
                while ( v ) {
                    for ( int k = 0 ; k < vec[v].size () ; k ++ ) {
                        Point u = vec[v][k] ;
                        ret *= (l+u.len)*u.p ;
                        ret %= mod ;
                    }
                    v = fail[v] ;
                }
                p[l].elem[u][i] += ret ;
                if ( p[l].elem[u][i] >= mod )
                    p[l].elem[u][i] -= mod ;
            }
        }
    }
    void RECT_MUIL ( RECT x , RECT y , RECT &z , int mod ) {
        memset ( z.elem , 0 , sizeof ( z.elem ) ) ;
        for ( int i = 0 ; i < tot ; i ++ ) {
            for ( int j = 0 ; j < tot ; j ++ )
                for ( int k = 0 ; k < tot ; k ++ ) {
                    z.elem[i][j] += x.elem[i][k] * y.elem[k][j] % mod ;
                    if ( z.elem[i][j] >= mod )
                        z.elem[i][j] -= mod ;
                }
        }
    }
    void GAO ( RECT& ret , ll n , int mod ) {
  //      printf ( "n = %I64d\n" , n ) ;
        RECT f = ret ; ret = E ;
        while ( n ) {
            if ( n & 1 ) RECT_MUIL ( ret , f , ret , mod ) ;
            RECT_MUIL ( f , f , f , mod ) ;
            n >>= 1 ;
        }
    }
    int SOLVE ( int mod , ll l ) {
        RECT ans = E , temp = E ;
    //    printf ( "mod = %d\n" , mod ) ;
        for ( int i = mod ; i >= 1 ; i -- ) {
            BUILD_RECT ( i , mod ) ;
            RECT_MUIL ( temp , p[i] , temp , mod ) ;
        //    if (i == 1) ans.print ( tot ) ;
        }
   //     puts( "fuck ") ;
        GAO ( temp , l/mod , mod ) ;
   //     ans.print ( tot ) ;
        for ( int i = l % mod ; i >= 1 ; i -- ) {
            BUILD_RECT ( i , mod ) ;
            RECT_MUIL ( ans , p[i] , ans , mod ) ;
        }
        RECT_MUIL ( ans , temp , ans , mod ) ;
    //    ans.print ( tot ) ;
        int ret = 0 ;
        for ( int i = 0 ; i < tot ; i ++ ) {
            ret += ans.elem[i][0] ;
            if ( ret >= mod ) ret -= mod ;
        }
        return ret ;
    }
    void DEBUG ( ll l ) {
        memset ( dp , 0 , sizeof ( dp ) ) ;
        dp[0][0] = 1 ;
        for ( int i = 0 ; i < l ; i ++ ) {
            for ( int j = 0 ; j < tot ; j ++ ) {
                for ( int k = 0 ; k < 26 ; k ++ ) {
                    int u = c[j][k] ;
                    int v = u ;
                    int ret = 1 ;
                    while ( v ) {
                        for ( int g = 0 ; g < vec[v].size () ; g ++ ) {
                            Point f = vec[v][g] ;
                            ret *= (i+1+f.len) * f.p ;
                        }
                        v = fail[v] ;
                    }
                    dp[i+1][u] += dp[i][j] * ret % mod ;
                    if ( dp[i+1][u] >= mod ) dp[i+1][u] -= mod ;
                }
            }
        }
        int ans = 0 ;
        for ( int i = 0 ; i < tot ; i ++ ) {
            ans += dp[l][i] ;
            if ( ans >= mod ) ans -= mod ;
        }
        puts ( "fuck" ) ;
        printf ( "%d\n" , ans ) ;
    }
} ac ;

void extend_gcd ( ll a , ll b , int &x , int &y ) {
    if ( !b ) x = 1 , y = 0 ;
    else extend_gcd ( b , a % b , y , x ) , y -= x * ( a / b ) ;
}

char s[1111] ;
int main () {
    for ( int i = 0 ; i < N ; i ++ )
        for ( int j = 0 ; j < N ; j ++ )
            E.elem[i][j] = i == j ;
    get_prime () ;
    int n ; ll l ;
    int ca = 0 ;
    while ( scanf ( "%d%I64d" , &n , &l ) != EOF ) {
        ac.init () ;
        for ( int i = 1 ; i <= n ; i ++ ) {
            scanf ( "%s" , s ) ;
            ac.insert ( s , i ) ;
        }
        ac.get_fail () ;
    //    ac.DEBUG ( l ) ;
        int m1 , mm1 , m2 , mm2 , m3 , mm3 , fuck ;//mm为m的乘法逆元
        m1 = 173 * 179 , m2 = 163 * 179 , m3 = 163 * 173 ;
        extend_gcd ( m1 , 163 , mm1 , fuck ) ;
        extend_gcd ( m2 , 173 , mm2 , fuck ) ;
        extend_gcd ( m3 , 179 , mm3 , fuck ) ;
        int a1 = ac.SOLVE ( 163 , l ) ;
    //    printf( "a1 = %d\n"  , a1 ) ;
        int a2 = ac.SOLVE ( 173 , l ) ;
  //      printf ( "a2 = %d\n" , a2 ) ;
        int a3 = ac.SOLVE ( 179 , l ) ;
 //       printf ( "a3 = %d\n" , a3 ) ;
        int ans = ( a1 * m1 * mm1 + a2 * m2 * mm2 + a3 * m3 * mm3 ) % 5047621 ;
        printf ( "Case #%d: %d\n" , ++ ca , ans ) ;
    }
    return 0 ;
}
/*
2 3
AB
BB
2 2
A
B
*/