首页 > 代码库 > HDU 4945 2048(DP)
HDU 4945 2048(DP)
HDU 4945 2048
题目链接
题意:给定一个序列,求有多少个子序列能合成2048
思路:把2,4,8..2048这些数字拿出来考虑就可以了,其他数字无论如何都不能参与组成,那么在这些数字基础上,dp[i][j]表示到第i个数字,和为j的情况数,然后对于每个数枚举取多少个,就可以利用组合数取进行状态转移,这里有一个剪枝,就是如果加超过2048了,那么后面数字的组合数的和全部都是加到2048上面,可以利用公式一步求解,这样的总体复杂度就可以满足题目了。然后这题时限卡得紧啊,10W内的逆元不先预处理出来就超时。
代码:
#include <cstdio> #include <cstring> #include <algorithm> using namespace std; typedef long long ll; const int MOD = 998244353; inline void scanf_(int &num)//无负数 { char in; while((in=getchar()) > '9' || in<'0') ; num=in-'0'; while(in=getchar(),in>='0'&&in<='9') num*=10,num+=in-'0'; } int n, v[2049], mi[15], m, cnt[15]; int dp[15][2049], mi2[100005], invv[100005]; bool istwo[2049]; void init() { int num; m = 0; memset(cnt, 0, sizeof(cnt)); for (int i = 0; i < n; i++) { scanf_(num); if (!istwo[num]) { m++; continue; } else cnt[v[num]]++; } } int inv(int n) { int ans = 1; int k = MOD - 2; while (k) { if (k&1) ans = (ll)ans * n % MOD; n = (ll)n * n % MOD; k >>= 1; } return ans; } int solve() { memset(dp, 0, sizeof(dp)); dp[0][0] = 1; for (int i = 1; i <= 12; i++) { for (int j = 0; j <= 2048; j += mi[i]) { if (dp[i - 1][j] == 0) continue; int C = 1, s = 0; int sum = j; for (int k = 0; k <= cnt[i]; k++) { int x = sum; if (x == 2048) { dp[i][x] = (ll)dp[i - 1][j] * (mi2[cnt[i]] - s) % MOD + dp[i][x]; if (dp[i][x] < 0) dp[i][x] += MOD; if (dp[i][x] >= MOD) dp[i][x] -= MOD; break; } if (x % mi[i + 1]) x = x - mi[i]; dp[i][x] = (ll)dp[i - 1][j] * C % MOD + dp[i][x]; if (dp[i][x] >= MOD) dp[i][x] -= MOD; s += C; if (s >= MOD) s -= MOD; C = (ll)C * (cnt[i] - k) % MOD * invv[k + 1] % MOD; sum += mi[i]; } } } return (ll)dp[12][2048] * mi2[m] % MOD; } int main() { memset(istwo, false, sizeof(istwo)); memset(v, -1, sizeof(v)); mi[0] = 0; v[0] = 0; for (int i = 1, j = 1; i <= 2048; i *= 2, j++) { istwo[i] = true; v[i] = j; mi[j] = i; } mi[13] = 4096; for (int i = 1; i <= 2048; i++) { if (v[i] == -1) v[i] = v[i - 1]; } mi2[0] = 1; for (int i = 1; i <= 100000; i++) { invv[i] = inv(i); mi2[i] = mi2[i - 1] * 2 % MOD; } int cas = 0; while (~scanf("%d", &n) && n) { init(); printf("Case #%d: %d\n", ++cas, solve()); } return 0; }
声明:以上内容来自用户投稿及互联网公开渠道收集整理发布,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任,若内容有误或涉及侵权可进行投诉: 投诉/举报 工作人员会在5个工作日内联系你,一经查实,本站将立刻删除涉嫌侵权内容。