首页 > 代码库 > POJ 2778 AC自动机+矩阵幂 不错的题

POJ 2778 AC自动机+矩阵幂 不错的题

http://poj.org/problem?id=2778

有空再重新做下,对状态图的理解很重要

题解:
http://blog.csdn.net/morgan_xww/article/details/7834801


另外做了矩阵幂的模板:

//ac.sz是矩阵的大小
void mulmtr(long long x[MAXNODE][MAXNODE],long long y[MAXNODE][MAXNODE])//y=x*y
{
    ll tmp[MAXNODE][MAXNODE];
    for(int i=0;i<ac.sz;i++)
    {
        for(int j=0;j<ac.sz;j++)
        {
            tmp[i][j]=0;
            for(int k=0;k<ac.sz;k++)
                tmp[i][j] +=x[i][k]*y[k][j];
            tmp[i][j] %=MOD;
        }
    }
    for(int i=0;i<ac.sz;i++)
        for(int j=0;j<ac.sz;j++)
            y[i][j]=tmp[i][j];
}
void Mtrmi(ll mtr[MAXNODE][MAXNODE],int n)
{
    for(int i=0;i<ac.sz;i++)
    {
        for(int j=0;j<ac.sz;j++)
        {
            if(i == j)ans[i][j]=1;//E矩阵
            else ans[i][j]=0;
        }
    }
    while(n)
    {
        if(n&1)
        {
            mulmtr(mtr,ans);
        }
        mulmtr(mtr,mtr);
        n/=2;
    }
}




代码:

#include <cstdio>
#include <cstring>
#include <string>
#include <map>
#include <queue>
#include <iostream>
using namespace std;

#define ll long long

const int MAXNODE  = 15*15;
const int SSIZE  = 2000000000+100;
const int MOD = 100000;
const int SIGMA_SIZE = 4;
const int SIZE = 20;

ll mtr[MAXNODE][MAXNODE];
ll ans[MAXNODE][MAXNODE];
int danger[MAXNODE];

struct AC
{
    int f[MAXNODE];
    int val[MAXNODE];
    int last[MAXNODE];
    int cnt[MAXNODE];
    int ch[MAXNODE][SIGMA_SIZE];
    int sz;

    void init()
    {
        memset(ch[0],0,sizeof(ch[0]));
        memset(cnt,0,sizeof(cnt));
        f[0]=0;///////////
        sz=1;
    }

    inline int idx(char x)
    {
        if(x == 'A')return 0;
        if(x == 'T')return 1;
        if(x == 'C')return 2;
        if(x == 'G')return 3;
    }

    void insert(char *s, int v)
    {
        int n=strlen(s),u=0;
        for(int i=0;i<n;i++)
        {
            int id= idx(s[i]);
            if(!ch[u][id])
            {
                memset(ch[sz],0,sizeof(ch[sz]));
                val[sz]=0;
                ch[u][id]=sz++;
            }
            u=ch[u][id];
        }
        val[u]=v;
        danger[u]=1;////////
    }

    void getfail()
    {
        queue<int>q;
        f[0]=0;
        for(int c=0;c<SIGMA_SIZE;c++)
        {
            int u=ch[0][c];
            if(u)
            {
                q.push(u);
                f[u]=0;
                last[u]=0;
            }
        }
        while(!q.empty())
        {
            int r=q.front();q.pop();
            for(int c=0;c<SIGMA_SIZE;c++)
            {
                int u=ch[r][c];
                //if(!u)continue;////////
                if(!u)
                {
                    ch[r][c]=ch[f[r]][c];//////
                    continue;
                }
                q.push(u);
                int v=f[r];
                while(v &&!ch[v][c])v=f[v];
                f[u]=ch[v][c];
                //last[u]=val[f[u]]?f[u]:last[f[u]];
                danger[u] |= danger[f[u]];
            }
        }
    }
};
void init()
{
    memset(mtr,0,sizeof(mtr));
    memset(danger,0,sizeof(danger));
}

AC ac;

char str[SIZE];

void mulmtr(long long x[MAXNODE][MAXNODE],long long y[MAXNODE][MAXNODE])//y=x*y
{
    ll tmp[MAXNODE][MAXNODE];
    for(int i=0;i<ac.sz;i++)
    {
        for(int j=0;j<ac.sz;j++)
        {
            tmp[i][j]=0;
            for(int k=0;k<ac.sz;k++)
                tmp[i][j] +=x[i][k]*y[k][j];
            tmp[i][j] %=MOD;
        }
    }
    for(int i=0;i<ac.sz;i++)
        for(int j=0;j<ac.sz;j++)
            y[i][j]=tmp[i][j];
}
void Mtrmi(ll mtr[MAXNODE][MAXNODE],int n)
{
    for(int i=0;i<ac.sz;i++)
    {
        for(int j=0;j<ac.sz;j++)
        {
            if(i == j)ans[i][j]=1;//E矩阵
            else ans[i][j]=0;
        }
    }
    while(n)
    {
        if(n&1)
        {
            mulmtr(mtr,ans);
        }
        mulmtr(mtr,mtr);
        n/=2;
    }
}

int main()
{
    //freopen("poj2788.txt","r",stdin);
    int n,m;
    while(~scanf("%d%d",&m,&n))
    {
        init();
        ac.init();
        for(int i=1;i<=m;i++)
        {
            scanf("%s",str);
            ac.insert(str,i);
        }
        ac.getfail();
        for(int i=0;i<ac.sz;i++)
            if(!danger[i])
                for(int j=0;j<4;j++)
                    if(!danger[ac.ch[i][j]])
                    {
                        mtr[i][ac.ch[i][j]]++;
                    }

        Mtrmi(mtr,n);
                /////////////////////////////////
       /* for(int i=0;i<ac.sz;i++)
        {
            for(int j=0;j<ac.sz;j++)
                printf("%lld|%lld ",mtr[i][j],ans[i][j]);
            putchar('\n');
        }*/
        ///////////////////////
        for(int i=1;i<ac.sz;i++)
            ans[0][0]+=ans[0][i]%MOD;
        printf("%I64d\n",ans[0][0]%MOD);
    }
    return 0;
}