首页 > 代码库 > HDU 4652 Dice (概率DP)
HDU 4652 Dice (概率DP)
Dice
0 m n: ask for the expected number of tosses until the last n times results are all same.
1 m n: ask for the expected number of tosses until the last n consecutive results are pairwise different.
6 0 6 1 0 6 3 0 6 5 1 6 2 1 6 4 1 6 6 10 1 4534 25 1 1232 24 1 3213 15 1 4343 24 1 4343 9 1 65467 123 1 43434 100 1 34344 9 1 10001 15 1 1000000 2000
1.000000000 43.000000000 1555.000000000 2.200000000 7.600000000 83.200000000 25.586315824 26.015990037 15.176341160 24.541045769 9.027721917 127.908330426 103.975455253 9.003495515 15.056204472 4731.706620396
题目大意:
n边形的骰子,问你出现连续相同(不同)n次需要掷的次数的数学期望。
解题思路:
利用递归方式的DP的思想推公式
解题代码:(1)若询问为0,则:
dp[i] 记录的是已经连续i个相同,到n个不同需要的次数的数学期望
dp[0]= 1+dp[1]
dp[1]= 1+( 1/m*dp[2]+(m-1)/m*dp[1])=1+(dp[2]+(1-m)*dp[1])/m;
dp[2]= 1+(dp[3]+(1-m)*dp[1])/m;
....................
dp[n]= 0
推出:
dp[i] = 1 + ( (m-1)*dp[1] + dp[i+1] ) / m
dp[i+1] = 1 + ( (m-1)*dp[1] + dp[i+2] ) / m
因此,m*(dp[i+1]-dp[i])=(dp[i+2]-dp[i+1])
我们发现是等比数列
dp[0]-dp[1]=1;
dp[1]-dp[2]=m;
..........
dp[n-1]-dp[n]=m^(n-1)
累加,得:dp[0]-dp[n]=1+m+m^2+..........m^(n-1)=(1-m^n)/(1-m)
所以:dp[0]=(1-m^n)/(1-m);
(2)若询问为1,则:
dp[0] = 1 + dp[1]
dp[1] = 1 + (dp[1] + (m-1) dp[2]) / m
dp[2] = 1 + (dp[1] + dp[2] + (m-2) dp[3]) / m
dp[i] = 1 + (dp[1] + dp[2] + ... dp[i] + (m-i)*dp[i+1]) / m
dp[i+1]= 1 + (dp[1] + dp[2] + ... dp[i] + dp[i+1] + (m-i-1)*dp[i+1]) / m
...
dp[n] = 0;
选出 dp[i] 和 dp[i+1] 这两行相减 得
dp[i] - dp[i+1] = (m-i-1)/m * (dp[i+1] - dp[i+2]);
因此 dp[i+1] - dp[i+2] = m/(m-i-1)*(dp[i]-dp[i+1]);
所以:
dp[0]-dp[1]=1;
dp[1]-dp[2]=1*m/(m-1);
dp[2]-dp[3]=1*m/(m-1)*m/(m-2);
..........
dp[n-1]-dp[n]=1*m/(m-1)*m/(m-2)*.......*m/(m-n+1);
累加得到答案
#include <iostream> #include <cstdio> #include <cmath> using namespace std; inline double solve(){ int op,m,n; scanf("%d%d%d",&op,&m,&n); double ans=0; if(op==0){ for(int i=0;i<=n-1;i++){ ans+=pow(1.0*m,i); } }else{ double tmp=1.0; for(int i=1;i<=n;i++){ ans+=tmp; tmp*=m*1.0/(m-i); } } return ans; } int main(){ int t; while(scanf("%d",&t)!=EOF){ while(t-- >0){ printf( "%.9lf\n",solve() ); } } return 0; }