首页 > 代码库 > poj 2154 Color(polya计数 + 欧拉函数优化)

poj 2154 Color(polya计数 + 欧拉函数优化)

http://poj.org/problem?id=2154


大致题意:由n个珠子,n种颜色,组成一个项链。要求不同的项链数目,旋转后一样的属于同一种,结果模p。


n个珠子应该有n种旋转置换,每种置换的循环个数为gcd(i,n)。如果直接枚举i,显然不行。但是我们可以缩小枚举的数目。改为枚举每个循环节的长度L,那么相应的循环节数是n/L。所以我们只需求出每个L有多少个i满足gcd(i,n)= n/L,就得到了循环节数为n/L的个数。重点就是求出这样的i的个数。


令cnt = gcd(i,n) = n/L;

那么cnt | i,令i = cnt*t(0 <= t <= L);

又 n = cnt * L ;

所以gcd(i,n) = gcd( cnt*t, cnt*L) = cnt,

满足上式的条件是 gcd(t,L) = 1。

而这样的t 有Eular(L)个。

因此循环节个数是n/L的置换个数有Eular(L)个。

参考博客:http://blog.csdn.net/tsaid/article/details/7366708


代码中求欧拉函数是基于素数筛的,素数只需筛到sqrt(1e9)即可。我在筛素数的同时递推的记录了sqrt(1e9)以内的Eular(n),用phi[]表示。这样会快那么一点点。


#include <stdio.h>
#include <iostream>
#include <algorithm>
#include <set>
#include <map>
#include <vector>
#include <math.h>
#include <string.h>
#include <queue>
#include <string>
#include <stdlib.h>
#define LL long long
#define _LL __int64
#define eps 1e-8
#define PI acos(-1.0)
using namespace std;

const int maxn = 35000;
const int INF = 0x3f3f3f3f;

int n,p;
int ans;
int prime[maxn];
int flag[maxn];
int prime_num;
int phi[maxn];

int mod_exp(int a, int b, int c)
{
	int res = 1;
	a = a%c;
	while(b)
	{
		if(b&1)
			res = (res*a)%c;
		a = (a*a)%c;
		b >>= 1;
	}
	return res;
}

//素数筛并记录maxn以内的Eular(n),用phi[]表示
void get_prime()
{
	memset(flag,0,sizeof(flag));
	prime_num = 0;
	phi[1] = 1;
	for(int i = 2; i <= maxn; i++)
	{
		if(!flag[i])
		{
			prime[++prime_num] = i;
			phi[i] = i-1;
		}

		for(int j = 1; j <= prime_num && i*prime[j] <= maxn; j++)
		{
			flag[i*prime[j]] = 1;
			if(i % prime[j] == 0)
				phi[i*prime[j]] = phi[i] * prime[j];
			else phi[i*prime[j]] = phi[i] * (prime[j]-1);
		}
	}
}

int Eular(int n)
{
	if(n < maxn)
		return phi[n] % p;
	//求大于maxn的Eular(n)
	int res = n;
	for(int i = 1; prime[i]*prime[i] <= n && i <= prime_num; i++)
	{
		if(n % prime[i] == 0)
		{
			res -= res/prime[i];
			while(n%prime[i] == 0)
				n = n/prime[i];
		}
	}
	if(n > 1)
		res -= res/n;
	return res%p;
}

int main()
{

	int test;
	get_prime();
	scanf("%d",&test);

	while(test--)
	{
		scanf("%d %d",&n,&p);
		ans = 0;
		for(int l = 1; l*l <= n; l++)
		{
			if(l*l == n)
			{
				ans = (ans + Eular(l)*mod_exp(n,l-1,p))%p;
			}
			else if(n%l == 0) //循环节长度为l,那么n/l也是循环节长度
			{
				ans = (ans + Eular(l)*mod_exp(n,n/l-1,p))%p;
				ans = (ans + Eular(n/l)*mod_exp(n,l-1,p))%p;
			}
		}
		printf("%d\n",ans);
	}
	return 0;
}