首页 > 代码库 > poj 3150 Cellular Automaton(矩阵快速幂)

poj 3150 Cellular Automaton(矩阵快速幂)

http://poj.org/problem?id=3150


大致题意:给出n个数,问经过K次变换每个位置上的数变为多少。第i位置上的数经过一次变换定义为所有满足 min( abs(i-j),n-abs(i-j) )<=d的j位置上的数字之和对m求余。


思路:

我们先将上述定义表示为矩阵

B = 

1 1 0 0 1
1 1 1 0 0
0 1 1 1 0
0 0 1 1 1
1 0 0 1 1

B[i][j] = 表示i与j满足上述关系,B[i][j] = 0表示i与j不满足上述关系。根据这个矩阵,那么样例1中1 2 2 1 2经过一次变换变成了5 5 5 5 4。


其实这也是矩阵相乘的问题,令A = 1 2 2 1 2,那么A * B = 5 5 5 5 4。那么要经过K次变换,答案无疑是 A*(B^k)mod m。

用矩阵快速幂的复杂度为 O(n^3 * log k),n最大是500,K也很大,必会TLE。logk是不会变了,优化在于n^3。仔细观察B矩阵,发现它是有规律的,它的每一行都是它上一行右移一位得到的。那么在矩阵相乘时,我们只需计算第一行,然后整个矩阵就算出来了,这样复杂度降为O(n^2 * log k)。


A这道题真是太坎坷了。在矩阵相乘时我传的两个参数是结构体,里面是500*500的数组,一运行就崩了,一直找找不到原因,最后发现传参的问题,它相当于直接把两个结构体传过去,显然太大了,随后就改成指针传参,后来因为没有释放内存,1MLE,再后来把k和d 输反了,1WA,最后终于2000+ms过了。。


#include <stdio.h>
#include <iostream>
#include <algorithm>
#include <set>
#include <map>
#include <vector>
#include <math.h>
#include <string.h>
#include <queue>
#include <string>
#include <stdlib.h>
#define LL long long
#define _LL __int64
#define eps 1e-8
#define PI acos(-1.0)
using namespace std;

const int INF = 0x3f3f3f3f;
const int maxn = 510;

_LL b[maxn],ans[maxn];
int n,m,k,d;
int mod;

struct matrix
{
    _LL mat[maxn][maxn];
}a,*res;

matrix *matrixMul(matrix *x, matrix *y)
{
    matrix *tmp;
    tmp = (matrix *)malloc(sizeof(matrix));
    memset((*tmp).mat,0,sizeof((*tmp).mat));
    for(int i = 0; i < 1; i++)
    {
        for(int k = 0; k < n; k++)
        {
            if( (*x).mat[i][k] == 0) continue;
            for(int j = 0; j < n; j++)
            {
                (*tmp).mat[i][j] += (*x).mat[i][k] * (*y).mat[k][j];
                if((*tmp).mat[i][j] >= mod)
                    (*tmp).mat[i][j] %= mod;
            }
        }
    }

	for(int i = 0; i < n; i++)
	{
		for(int j = 0; j < n; j++)
		{
			if(i == 0) (*x).mat[i][j] = (*tmp).mat[i][j];
			else (*x).mat[i][j] = (*x).mat[i-1][(j-1+n)%n];
		}
	}
	free(tmp);
    return x;
}

matrix *Mul(matrix *x, int k)
{
    matrix *tmp;
	tmp = (matrix *)malloc(sizeof(matrix));
	memset((*tmp).mat,0,sizeof((*tmp).mat));
	for(int i = 0; i < n; i++)
        (*tmp).mat[i][i] = 1;

    while(k)
    {
        if(k&1)
            tmp = matrixMul(tmp,x);
        x = matrixMul(x,x);
        k >>= 1;
    }
    return tmp;
}

int main()
{
	while(~scanf("%d %d %d %d",&n,&m,&d,&k))
	{
		mod = m;
		for(int i = 0; i < n; i++)
			scanf("%I64d",&b[i]);

		for(int i = 0; i < n; i++)
		{
			for(int j = 0; j < n; j++)
			{
				if(min (abs(i-j),n-abs(i-j)) <= d)
					a.mat[i][j] = 1;
				else a.mat[i][j] = 0;
			}
		}

		res = Mul(&a,k);

		memset(ans,0,sizeof(ans));

		for(int i = 0; i < n; i++)
		{
			for(int j = 0; j < n; j++)
			{
				ans[i] += b[j] * ((*res).mat[j][i]);
				if(ans[i] >= mod)
					ans[i] %= mod;
			}
		}
		for(int i = 0; i < n-1; i++)
			printf("%I64d ",ans[i]);
		printf("%I64d\n",ans[n-1]);
   }
    return 0;
}