首页 > 代码库 > zoj3494BCD Code(ac自动机+数位dp)

zoj3494BCD Code(ac自动机+数位dp)

l链接

这题想了好一会呢。。刚开始想错了,以为用自动机预处理出k长度可以包含的合法的数的个数,然后再数位dp一下就行了,写到一半发现不对,还要处理当前走的时候是不是为合法的,这一点无法移到trie树上去判断。

之后想到应该在trie树上进行数位dp,走到第i个节点且长度为j的状态是确定的,所以可以根据trie树上的节点来进行确定状态。

dp[i][j]表示当前节点为i,数第j位时可以包含多少个合法的数。

  1 #include <iostream>
  2 #include<cstdio>
  3 #include<cstring>
  4 #include<string>
  5 #include<algorithm>
  6 #include<stdlib.h>
  7 #include<vector>
  8 #include<cmath>
  9 #include<queue>
 10 #include<set>
 11 using namespace std;
 12 #define N 2010
 13 #define LL long long
 14 #define INF 0xfffffff
 15 const double eps = 1e-8;
 16 const double pi = acos(-1.0);
 17 const double inf = ~0u>>2;
 18 const int child_num = 2;
 19 const int mod = 1000000009;
 20 int dp[210][N];
 21 char s1[210],s2[210];
 22 class AC
 23 {
 24     private:
 25     int ch[N][child_num];
 26     int Q[N];
 27     int fail[N];
 28     int val[N];
 29     int id[127];
 30     int sz;
 31     int dd[810][N];
 32     public:
 33     void init()
 34     {
 35         fail[0] = 0;
 36         id[0] = 0;id[1] = 1;
 37     }
 38     void reset()
 39     {
 40         memset(val,0,sizeof(val));
 41         memset(ch[0],0,sizeof(ch[0]));
 42         sz=1;
 43     }
 44     void insert(char *a,int key)
 45     {
 46         int p =0 ;
 47         for( ; *a ; a++)
 48         {
 49             int d = id[*a];
 50             if(ch[p][d]==0){
 51                 memset(ch[sz],0,sizeof(ch[sz]));
 52                 ch[p][d] = sz++;
 53             }
 54             p = ch[p][d];
 55         }
 56         val[p] = key;
 57     }
 58     void construct()
 59     {
 60         int i,head=0,tail = 0;
 61         for(i = 0 ;i < child_num ; i++)
 62         {
 63             if(ch[0][i])
 64             {
 65                 fail[ch[0][i]] = 0;
 66                 Q[tail++] = ch[0][i];
 67             }
 68         }
 69         while(head!=tail)
 70         {
 71             int u = Q[head++];
 72             val[u]|=val[fail[u]];
 73             for(i =0 ;i < child_num ; i++)
 74             {
 75                 if(ch[u][i])
 76                 {
 77                     fail[ch[u][i]] = ch[fail[u]][i];
 78                     Q[tail++] = ch[u][i];
 79                 }
 80                 else ch[u][i] = ch[fail[u]][i];
 81             }
 82         }
 83     }
 84     int dfs(char *s,int i,int c,int e,int k)
 85     {
 86         if(i==-1)
 87         {
 88             return 1;
 89         }
 90         if(!e&&~dp[i][c])
 91         {
 92             return dp[i][c];
 93         }
 94         int mk = e?s[i]-0:9;
 95         int ans = 0;
 96         for(int j = 0; j <= mk ; j++)
 97         {
 98             if(!k&&j==0&&i)
 99             {
100                 ans = (ans+dfs(s,i-1,c,e&&j==mk,k));
101                 continue;
102             }
103             int p = c,flag = 1;
104             for(int g = 3; g >=0 ; g--)
105             {
106                 int o = (j&(1<<g))?1:0;
107                 p = ch[p][o];
108                 int tmp = p;
109                 while(tmp!=0)
110                 {
111                     if(val[tmp])
112                     {
113                         flag = 0;
114                         break;
115                     }
116                     tmp = fail[tmp];
117                 }
118                 if(!flag) break;
119             }
120             if(flag)
121             {
122                 ans = (ans+dfs(s,i-1,p,e&&j==mk,1))%mod;
123             }
124         }
125         return e?ans:dp[i][c] = ans;
126     }
127     void work(char *s1,char *s2)
128     {
129         memset(dp,-1,sizeof(dp));
130         printf("%d\n",(dfs(s2,strlen(s2)-1,0,1,0)-dfs(s1,strlen(s1)-1,0,1,0)+mod)%mod);
131     }
132 }ac;
133 char vir[22];
134 char ss1[210],ss2[210];
135 int main()
136 {
137     int t,n,i;
138     ac.init();
139     scanf("%d",&t);
140     while(t--)
141     {
142         ac.reset();
143         scanf("%d",&n);
144         while(n--)
145         {
146             scanf("%s",vir);
147             ac.insert(vir,1);
148         }
149         ac.construct();
150         scanf("%s%s",s1,s2);
151         int k = strlen(s1),kk= strlen(s2);
152         for(i = k-1 ; i >= 0; i--)
153         {
154             if(s1[i]>0)
155             {
156                 s1[i]-=1;
157                 break;
158             }
159             else
160             s1[i] = 9;
161         }
162         for(i = 0; i < k ; i++)
163         ss1[k-1-i] = s1[i];
164         ss1[k] = \0;
165         for(i = 0; i < kk ; i++)
166         ss2[kk-1-i] = s2[i];
167         ss2[kk] = \0;
168         ac.work(ss1,ss2);
169     }
170     return 0;
171 }
View Code