首页 > 代码库 > poj 3735 Training little cats 矩阵

poj 3735 Training little cats 矩阵

假设n=3

构造矩阵【1,0,0,0】

对于g 1操作,构造矩阵(0行i列++)

1 1 0 0

0 1 0 0

0 0 1 0

0 0 0 1

对于e 1操作,构造矩阵 (i整列清空)

1 0 0 0

0 0 0 0

0 0 1 0

0 0 0 1

对于s 1 2操作,构造矩阵 (i,j整列交换)

0 0 0

0 1 0

0 1 0 0

0 0 0 1

将k次操作依次按上述构造矩阵,得到一个轮回的转置矩阵。做m次快速幂就行了。

最坑的地方在于,答案要用longlong存,而longlong在做矩阵时,相乘次数太多会超时,看了discuss才知道加一行if判断不为0相乘就能过。

#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <iostream>
using namespace std;
typedef long long ll;
struct matrix
{
       ll a[105][105];
       int n,m;
}origin,res,ans;
matrix multiply(matrix &x,matrix &y)
{
    matrix temp;
    temp.n=x.n;
    temp.m=y.m;
    for(int i=0;i<x.n;i++)
    {
        for(int j=0;j<y.m;j++)
        {
            temp.a[i][j]=0;
            for(int k=0;k<y.m;k++)
            {
                if(x.a[i][k]&&y.a[k][j])        //关键点,不加会超时,因为longlong相乘太多次了
                temp.a[i][j]+=x.a[i][k]*y.a[k][j];
            }
        }
    }
    return temp;
}
void init(int n)
{
    memset(origin.a,0,sizeof(origin.a));
    for(int i=0;i<=n;i++) origin.a[i][i]=1;
    origin.n=n+1;
    origin.m=n+1;
    memset(res.a,0,sizeof(res.a));
    res.a[0][0]=1;
    res.n=1;
    res.m=n+1;
}
void calc(ll n)
{
     if(n<=0) {ans=res;ans.a[0][0];return;}
     while(n>1)
     {
        if(n&1){n--;
        res=multiply(res,origin);}
        n>>=1;
        origin=multiply(origin,origin);
     }
     ans=multiply(res,origin);
}


int main()
{
    int m,n,k,x,y;
    char q;
    while(cin>>n>>m>>k)
    {
        if(m+n+k==0) break;
        init(n);
        for(int i=1;i<=k;i++)
        {
            cin>>q;
            if(q=='g')
            {
                cin>>x;
                origin.a[0][x]++;
            }
            else if(q=='e')
            {
                cin>>x;
                for(int i=0;i<=n;i++) origin.a[i][x]=0;
            }
            else
            {
                cin>>x>>y;
                for(int i=0;i<=n;i++) swap(origin.a[i][x],origin.a[i][y]);
            }

        }
        calc(m);
        for(int i=1;i<=n;i++)
        {
            if(i==1) printf("%lld",ans.a[0][i]);
            else printf(" %lld",ans.a[0][i]);
        }
        puts("");
    }
    return 0;
}