首页 > 代码库 > poj 3744 矩阵加速--概率DP

POJ 3744：矩阵快速幂加速的概率 DP

http://poj.org/problem?id=3744


犯二了：递推式和矩阵快速幂都会写，但我推出来的矩阵跟别人的不一样，应该是对矩阵的理解有问题，再看看。（本代码使用的递推式为 $f_i = p\,f_{i-1} + (1-p)\,f_{i-2}$，其中 $f_i$ 表示踩到位置 $i$ 的概率。）

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <string>
#include <iostream>
#include <iomanip>
#include <cmath>
#include <map>
#include <set>
#include <queue>
#include <vector>
using namespace std;

// Competitive-programming shorthand macros from the author's template.
// None of them is used by the code in this file.
//
// Every macro argument and every arithmetic expansion is parenthesized so
// that operator precedence at the call site cannot change the meaning,
// e.g. ls(a+b) is (a+b)*2, not a+b*2.

// Left / right child index of node rt in a 1-based implicit binary tree.
#define ls(rt) ((rt)*2)
#define rs(rt) ((rt)*2+1)
#define ll long long
#define ull unsigned long long
// Half-open loop over [s, e) and closed loop over [s, e].
#define rep(i,s,e) for(int i=(s);i<(e);i++)
#define repe(i,s,e) for(int i=(s);i<=(e);i++)
// memset every byte of a with b; a must be an array (sizeof of a pointer
// would only clear the pointer's size).
#define CL(a,b) memset((a),(b),sizeof(a))
// Redirect stdin / stdout to a file (cf. the commented-out IN(...) in main).
#define IN(s) freopen(s,"r",stdin)
#define OUT(s) freopen(s,"w",stdout)
// Capacity of the fixed square matrix buffers. Only the top-left sz x sz
// sub-matrix is ever used.
const int MAXN=3;
// Base matrix handed to mtrmi(); main() refills it before every call.
double mtr[MAXN][MAXN];
// Output of mtrmi(): the most recently computed matrix power.
double ansm[MAXN][MAXN];
// Active matrix dimension. The recurrence f[i] = p*f[i-1] + (1-p)*f[i-2]
// needs a 2x2 transition matrix. Nothing ever changes it, so it is const.
// It must not exceed MAXN.
const int sz=2;
// Multiply two sz x sz matrices and write the product back into y:
//     y <- x * y
// The product is built in a scratch buffer before y is overwritten, so it is
// safe to call with x and y referring to the same matrix (i.e. squaring).
void mulmtr(double x[MAXN][MAXN], double y[MAXN][MAXN])
{
    double prod[MAXN][MAXN];

    for(int row = 0; row < sz; ++row)
    {
        for(int col = 0; col < sz; ++col)
        {
            // Dot product of row `row` of x with column `col` of y.
            double acc = 0;
            for(int mid = 0; mid < sz; ++mid)
                acc += x[row][mid] * y[mid][col];
            prod[row][col] = acc;
        }
    }

    // Publish the finished product into the destination matrix.
    for(int row = 0; row < sz; ++row)
        for(int col = 0; col < sz; ++col)
            y[row][col] = prod[row][col];
}

// Raise the sz x sz matrix `mtr` to the n-th power by binary exponentiation
// and store the result in the global `ansm`.
//
// mtr : base matrix. It is copied into a local buffer before the repeated
//       squaring, so the caller's matrix is left untouched. The parameter
//       shadows the global of the same name.
// n   : exponent, expected to be >= 0. n == 0 leaves ansm as the identity.
//       A negative n still terminates (n/=2 truncates toward zero) but the
//       result is not a meaningful matrix power.
//
// Performs O(log n) calls to mulmtr().
void mtrmi(double mtr[MAXN][MAXN],int n)
{
    // Private copy of the base: squaring it in place must not destroy the
    // caller's matrix.
    double base[MAXN][MAXN];
    for(int i=0;i<sz;i++)
        for(int j=0;j<sz;j++)
            base[i][j]=mtr[i][j];

    // Start the accumulator at the identity matrix.
    for(int i=0;i<sz;i++)
        for(int j=0;j<sz;j++)
            ansm[i][j]=(i == j) ? 1.0 : 0.0;

    while(n)
    {
        if(n&1)
            mulmtr(base,ansm);   // ansm <- base * ansm; powers of one matrix commute
        mulmtr(base,base);       // base <- base^2
        n/=2;
    }
}
int bomb[MAXN*10];
int main()
{
    //IN("poj3744.txt");
    int n;
    int last;
    double p;
    double ans=1.0;
    int pos;
    while(~scanf("%d%lf",&n,&p))
    {
        ans=1.0;
        last=1;
        for(int i=0;i<n;i++)
            scanf("%d",&bomb[i]);
        sort(bomb,bomb+n);
        for(int i=0;i<n;i++)
        {
            pos=bomb[i]-last;
            /*if(pos == 0){last=pos+1;continue;}
            if(pos == 1){ans*=(1-p);last=pos+1;continue;}
            if(pos == 2){ans*=p;last=pos+1;continue;}*/
            //pos-=2;
            mtr[0][0]=p;
            mtr[0][1]=1.0-p;
            mtr[1][0]=1.0;
            mtr[1][1]=0.0;
            mtrmi(mtr,pos);
            ans*=1.0-ansm[0][0];
            last=bomb[i]+1;
        }
        cout << fixed << setprecision(7) << ans << endl;
    }
    return 0;
}