首页 > 代码库 > SPOJ 8059. Stocks Prediction (矩阵嵌矩阵)

SPOJ 8059. Stocks Prediction (矩阵嵌矩阵)

题目大意:

给出一个递推的关系。

这个递推的关系可以求出 s_1 s_2 s_3 .... s_m

然后再告诉一个 k 与 n

求出segma( s_k , s_2*k , s_3*k)...共n项。


思路分析:

首先给出来的是递推关系式。

所以可以用一个矩阵递推出 s [i]...

但是他要的是每隔k的值。

定义s的递推矩阵是 A

SUM = S_k + S_2*K+ .... + S_N*K
SUM = S_k + A^k * S_k + A^2k * S_k...

SUM = (E + A^k + A^2*k + A^3*k....A^(n-1) ) * s_k..

所以我们用KK 矩阵表示 A^k

就又把问题转化成了

求 E+ A^k +A^2*k + A^3*k ....的和了。

所以就再一定义个矩阵,这个矩阵的每个元素都是一个小矩阵。

然后继续递推求值。

递推的矩阵是

E E

0  A^K


#include <cstdio>
#include <iostream>
#include <cstring>
#include <iostream>

using namespace std;
typedef long long LL;
LL mod=1000000007;
LL N;
struct matrix//N*N
{
    LL data[10][10];
    friend matrix operator * (const matrix A,const matrix B)
    {
        matrix res;
        memset(res.data,0,sizeof res.data);
        for(int i=0;i<N;i++)
        for(int j=0;j<N;j++)
        for(int k=0;k<N;k++)
        {
            res.data[i][j]+=(A.data[i][k]*B.data[k][j])%mod;
            res.data[i][j]%=mod;
        }
        return res;
    }
    friend matrix operator + (const matrix A,const matrix B)
    {
        matrix res;
        for(int i=0;i<N;i++)
        for(int j=0;j<N;j++)
        {
            res.data[i][j]=(A.data[i][j]+B.data[i][j])%mod;
            res.data[i][j]%=mod;
        }
        return res;
    }
    friend matrix operator - (const matrix A,const matrix B)
    {
        matrix res;
        for(int i=0;i<N;i++)
        for(int j=0;j<N;j++)
        {
            res.data[i][j]=((A.data[i][j]-B.data[i][j])+mod)%mod;
            res.data[i][j]%=mod;
        }
        return res;
    }
    void print()
    {
        for(int i=0;i<N;i++)
        {
            for(int j=0;j<N;j++)
            printf("%lld ",data[i][j]);
            puts("");
        }
    }

}E,zero;
struct supermax
{
    matrix ret[10][10];
    friend supermax operator * (supermax A,supermax B)
    {
        supermax res;
        for(int i=0;i<2;i++)
        for(int j=0;j<2;j++)
        res.ret[i][j]=zero;

        for(int i=0;i<2;i++)
        for(int j=0;j<2;j++)
        for(int k=0;k<2;k++)
        {
            res.ret[i][j]=res.ret[i][j]+(A.ret[i][k]*B.ret[k][j]);
            for(int p=0;p<N;p++)
            for(int q=0;q<N;q++)
            res.ret[i][j].data[p][q]%=mod;
        }
        return res;
    }
};

matrix matmod(matrix origin,LL n)
{
    matrix res=E;

    while(n)
    {
        if(n&1)
        res=res*origin;
        n>>=1;
        origin=origin*origin;
    }
    return res;
}
supermax Do(supermax origin,LL n)//2*2
{
    supermax res;
    for(int i=0;i<2;i++)
    for(int j=0;j<2;j++)
    res.ret[i][j]=zero;
    for(int i=0;i<2;i++)
    res.ret[i][i]=E;

    while(n)
    {
        if(n&1)
        res=res*origin;
        n>>=1;
        origin=origin*origin;
    }
    return res;
}

LL S[10];
LL a_[10];

int main()
{
    memset(zero.data,0,sizeof zero.data);
    memset(E.data,0,sizeof E.data);
    LL n,r,k;
    int CASE;
    scanf("%d",&CASE);
    while(CASE--)
    {
        scanf("%lld%lld%lld",&n,&r,&k);
        N=r;

        for(int i=0;i<10;i++)
        E.data[i][i]=1;

        for(int i=1;i<=r;i++)
        {
            scanf("%lld",&S[i]);
            S[i]%=mod;
        }

        for(int i=1;i<=r;i++)
        {
            scanf("%lld",&a_[i]);
            a_[i]%=mod;
        }

        matrix init;
        for(int i=0;i<r;i++)
            init.data[i][0]=S[r-i];
        matrix fib;
        fib=zero;

        LL fans=0;
        LL fuck=1;
        while(fuck*k<=r)
        {
            fans+=S[fuck*k];
            fuck++;
        }
        LL b=fuck*k;
        for(int i=0;i<r;i++)
        {
            fib.data[0][i]=a_[i+1];
            fib.data[i+1][i]=1;
        }
        matrix st = matmod(fib,b-r);
        st=st*init;
        matrix K=matmod(fib,(LL)k);

        supermax o;
        o.ret[0][0]=E;
        o.ret[0][1]=E;
        o.ret[1][0]=zero;
        o.ret[1][1]=K;

        n=(n*k-b)/k+1;
        supermax final=Do(o,n);
        matrix tmp=(final.ret[0][0]*zero)+(final.ret[0][1]*E);

        matrix B=E;
        matrix ans = tmp*st;
        printf("%lld\n",(fans+ans.data[0][0])%mod);
    }
    return 0;
}
/*
5
10 5 7
1 2 3 4 5
1 2 3 4 5

10 5 7
5 4 3 2 1
5 4 3 2 1

10 5 7
123 456 789 987 654
6 78 9 7 6

10 5 7
34587 98237598 123134 523454 5243
2598 73897 54897 8978979 31242542

1 5 7
1 2 3 4 5
1 2 3 4 5

5
10 7 100
4890 78678 6876 54465 798798 567576 89412
3545 979123124 789475 32789 6786786 5675675 89789
*/