首页 > 代码库 > [BZOJ2734][HNOI2012]集合选数

[BZOJ2734][HNOI2012]集合选数

[BZOJ2734][HNOI2012]集合选数

试题描述

《集合论与图论》这门课程有一道作业题,要求同学们求出{1, 2, 3, 4, 5}的所有满足以下条件的子集:若 x 在该子集中,则 2x 和 3x 不能在该子集中。同学们不喜欢这种具有枚举性质的题目,于是把它变成了以下问题:对于任意一个正整数 n≤100000,如何求出{1, 2,..., n} 的满足上述约束条件的子集的个数(只需输出对 1,000,000,001 取模的结果),现在这个问题就交给你了。 

输入

只有一行,其中有一个正整数 n,30%的数据满足 n≤20。

输出

仅包含一个正整数,表示{1, 2,..., n}有多少个满足上述约束条件 的子集。 

输入示例

4

输出示例

8

数据规模及约定

有8 个集合满足要求,分别是空集,{1},{1,4},{2},{2,3},{3},{3,4},{4}。

题解

构造 + 状压 dp。

我们可以构造出这样一个矩阵:

1  3  9  27  81  ...

2  6  18 54  162  ...

4  12 36 108  324  ...

...  ...  ...  ...  ...

即,每个元素 x 的上方是 x / 2,下方是 2x,左边是 x / 3,右边是 3x,那么这个矩阵中选不相邻的数字就是一个合法方案,计算这个方案数用状压 dp 即可。

注意每列形如 a, 2a, 4a, ...,每行形如 b, 3b, 9b, ...,所以只要每行(列)的开头不一样,那么整行(列)就不会有重复数字;同理,只要两个同样构造方法的矩阵的左上角数字不一样,这两个矩阵的数字不会有交集。

那么我们对于所有剩下的数字的最小的作为左上角 dp 一下,每个矩阵的方案数乘积就是答案。

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <algorithm>
using namespace std;

int read() {
	int x = 0, f = 1; char c = getchar();
	while(!isdigit(c)){ if(c == ‘-‘) f = -1; c = getchar(); }
	while(isdigit(c)){ x = x * 10 + c - ‘0‘; c = getchar(); }
	return x * f;
}

#define maxn 100010
#define maxlog 18
#define maxs 2048
#define MOD 1000000001
#define LL long long

int n, len[maxlog], f[maxlog][maxs];
bool used[maxn];

int solve(int x) {
	memset(len, 0, sizeof(len));
	int N = 0;
	for(int j = 1; x <= n; j++, x *= 3)
		for(int i = 1, t = x; t <= n; i++, t <<= 1) {
			len[i] = max(len[i], j);
			used[t] = 1;
			N = max(N, i);
		}
	memset(f, 0, sizeof(f));
	f[0][0] = 1;
	for(int i = 1; i <= N; i++) {
		int all = (1 << len[i]) - 1, all1 = (1 << len[i-1]) - 1;
		for(int S = 0; S <= all; S++) {
			bool ok = 1;
			for(int j = 0; j < len[i] - 1; j++)
				if((S >> j & 1) && (S >> j + 1 & 1)){ ok = 0; break; }
			if(!ok) continue;
			for(int S1 = 0; S1 <= all1; S1++) if(f[i-1][S1] && !(S & S1)) {
				f[i][S] += f[i-1][S1];
				if(f[i][S] >= MOD) f[i][S] -= MOD;
			}
		}
	}
	int ans = 0;
	for(int S = 0; S <= (1 << len[N]) - 1; S++) {
		ans += f[N][S];
		if(ans >= MOD) ans -= MOD;
	}
	return ans;
}

int main() {
	n = read();
	
	int ans = 1;
	for(int i = 1; i <= n; i++) if(!used[i])
		ans = (LL)ans * solve(i) % MOD;
	printf("%d\n", ans);
	
	return 0;
}

这题不能用 fread。。。

[BZOJ2734][HNOI2012]集合选数