首页 > 代码库 > bnu 12639 Cards (dp求期望)

bnu 12639 Cards (dp求期望)

bnu 12639 Cards

用 DP(记忆化搜索)求期望。

注意区分“全局最优选择”和“当前最优选择”。

本题是当前最优选择:每抽到一张王牌,就要立即决定把它当作哪种花色,取使剩余期望抽牌数最小的那种;而不是等牌全部抽完后再统一分配王牌。

状态表示(前四维为四种花色已计入的张数,含已分配给该花色的王;后两维分别为两张王被当作的花色编号 1~4,0 表示尚未抽到):

double dp[16][16][16][16][5][5];
bool vis[16][16][16][16][5][5];

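搜索时携带的参数:

vector<int> tmp 为当前状态(即上面 dp 的 6 个下标);vector<int> up 的前四项为每种花色最多可计入的张数(13 加上已分配给该花色的王),后两项表示两张王是否仍在牌堆中。由于 up 可以由 tmp 推出,它不需要出现在记忆化的下标里。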

因此采用记忆化搜索 + 回溯:递归前修改 tmp / up,递归返回后再恢复原值。


//#pragma warning (disable: 4786)
//#pragma comment (linker, "/STACK:16777216")
//HEAD
#include <cstdio>
#include <ctime>
#include <cstdlib>
#include <cstring>
#include <queue>
#include <string>
#include <set>
#include <stack>
#include <map>
#include <cmath>
#include <vector>
#include <iostream>
#include <algorithm>
using namespace std;
//LOOP
// Loop helpers. FE: i from a up to b inclusive. FD: i from b down to a
// inclusive. REP: i from 0 to N-1. CLR: memset every byte of array A to
// `value` (so only byte patterns such as 0 or -1 give meaningful int fills).
// CPY: copy sizeof(a) bytes from b into a. FC: iterate container c with an
// iterator (relies on the GNU __typeof extension).
// Of these, only REP is used by the program below.
#define FE(i, a, b) for(int i = (a); i <= (b); ++i)
#define FD(i, b, a) for(int i = (b); i>= (a); --i)
#define REP(i, N) for(int i = 0; i < (N); ++i)
#define CLR(A,value) memset(A,value,sizeof(A))
#define CPY(a, b) memcpy(a, b, sizeof(a))
#define FC(it, c) for(__typeof((c).begin()) it = (c).begin(); it != (c).end(); it++)
//INPUT
// scanf shorthands for one/two/three ints and a C string. Each expands to the
// scanf call, so its return value (number of fields read) is available.
// Only RI is used by the program below.
#define RI(n) scanf("%d", &n)
#define RII(n, m) scanf("%d%d", &n, &m)
#define RIII(n, m, k) scanf("%d%d%d", &n, &m, &k)
#define RS(s) scanf("%s", s)
//OUTPUT
// printf shorthands: one int or one C string followed by a newline (unused below).
#define WI(n) printf("%d\n", n)
#define WS(s) printf("%s\n", s)

// Template constants; none of LL, INF, eps or MAXN is referenced by the program below.
typedef long long LL;
const int INF = 1000000007;
const double eps = 1e-10;
const int MAXN = 1010;

// goal[i]: minimum number of cards that must be counted toward suit i
// (four values, filled in by main). Shared by check() and dfs().
std::vector<int> goal;
// Memo tables indexed by the six-entry state vector `tmp` used by dfs():
//   tmp[0..3] - cards counted toward each suit (natural cards plus any joker
//               assigned to that suit), at most 13 + 2 = 15;
//   tmp[4..5] - suit (1..4) that joker 1 / joker 2 was assigned to,
//               0 while that joker is still in the deck.
double dp[16][16][16][16][5][5];
bool vis[16][16][16][16][5][5];

// True once every suit has reached its target count.
inline bool check(const std::vector<int> &tmp)
{
    for (int i = 0; i < 4; i++)
        if (goal[i] > tmp[i]) return false;
    return true;
}

// Expected number of further draws needed to reach `goal` from state `tmp`.
// Each joker's suit is chosen at the moment it is drawn, picking the suit that
// minimises the expected remaining draws.
//
// tmp: state vector (layout described above the memo tables); modified during
//      recursion and restored before returning.
// up:  up[0..3] = cards that can ever be counted toward suit i (13 natural
//      cards plus jokers already assigned to it); up[4], up[5] = 1 while the
//      corresponding joker is still in the deck. `up` is fully determined by
//      `tmp`, which is why it is not part of the memo key.
// Returns HUGE_VAL for states from which the goal can no longer be reached
// (deck exhausted, some suit still short). The joker-assignment minimum
// therefore never prefers such a state over a reachable one.
double dfs(std::vector<int> &tmp, std::vector<int> &up)
{
    if (check(tmp)) return 0.0;

    double &memo = dp[tmp[0]][tmp[1]][tmp[2]][tmp[3]][tmp[4]][tmp[5]];
    bool &seen = vis[tmp[0]][tmp[1]][tmp[2]][tmp[3]][tmp[4]][tmp[5]];
    if (seen) return memo;
    seen = true;

    // Every drawn card (natural or joker) is counted in exactly one of tmp[0..3].
    int drawn = 0;
    for (int i = 0; i < 4; i++) drawn += tmp[i];
    const int remaining = 54 - drawn;
    // Deck exhausted but goal unmet: previously this fell through and returned 1.
    if (remaining <= 0) return memo = HUGE_VAL;

    double ret = 0;

    // Drawing a natural card of suit i: up[i] - tmp[i] of them are still in the deck.
    for (int i = 0; i < 4; i++)
    {
        if (up[i] > tmp[i])
        {
            const double left = up[i] - tmp[i];
            tmp[i]++;
            ret += left / remaining * dfs(tmp, up);
            tmp[i]--;
        }
    }

    // Drawing joker 1 (slot 4) or joker 2 (slot 5): assign it to the best suit.
    for (int slot = 4; slot <= 5; slot++)
    {
        if (!up[slot]) continue;
        double best = HUGE_VAL;
        for (int i = 0; i < 4; i++)
        {
            tmp[i]++; up[i]++; up[slot]--; tmp[slot] = i + 1;
            best = std::min(best, 1.0 / remaining * dfs(tmp, up));
            tmp[i]--; up[i]--; up[slot]++; tmp[slot] = 0;
        }
        ret += best;
    }

    // +1 for the draw made from this state.
    return memo = ret + 1;
}

// Reads T test cases, each consisting of four per-suit target counts.
// For each case, prints the minimal expected number of draws, or -1.000 when
// the targets cannot be met from a 54-card deck (13 cards per suit + 2 jokers).
// Stops quietly if the input ends early or is malformed, instead of using
// uninitialized values.
int main ()
{
    int T;
    if (scanf("%d", &T) != 1) return 0;
    for (int ncase = 1; ncase <= T; ncase++)
    {
        int excess = 0, total = 0;
        goal.clear();
        for (int i = 0; i < 4; i++)
        {
            int x;
            if (scanf("%d", &x) != 1) return 0;
            goal.push_back(x);
            // A suit has only 13 natural cards; anything beyond must come from a joker.
            if (x > 13) excess += x - 13;
            total += x;
        }
        double ans;
        if (total > 54 || excess > 2)
        {
            ans = -1.000;
        }
        else
        {
            memset(vis, 0, sizeof(vis));
            std::vector<int> tmp(6, 0);
            // Four suits of 13 natural cards, both jokers still in the deck.
            std::vector<int> up(4, 13);
            up.push_back(1);
            up.push_back(1);
            ans = dfs(tmp, up);
        }
        printf("Case %d: %.3lf\n", ncase, ans);
    }
    return 0;
}


稍微有点不同的写法:tmp 的前四项只记录各花色抽到的普通牌张数(最多 13),王牌的归属只记在后两项;判断是否达成目标时再把王计入对应花色。因此不再需要 up 数组,dp 的前四维也缩小为 14。

//#pragma warning (disable: 4786)
//#pragma comment (linker, "/STACK:16777216")
//HEAD
#include <cstdio>
#include <ctime>
#include <cstdlib>
#include <cstring>
#include <queue>
#include <string>
#include <set>
#include <stack>
#include <map>
#include <cmath>
#include <vector>
#include <iostream>
#include <algorithm>
using namespace std;
//LOOP
// Loop helpers. FE: i from a up to b inclusive. FD: i from b down to a
// inclusive. REP: i from 0 to N-1. CLR: memset every byte of array A to
// `value` (so only byte patterns such as 0 or -1 give meaningful int fills).
// CPY: copy sizeof(a) bytes from b into a. FC: iterate container c with an
// iterator (relies on the GNU __typeof extension).
// Of these, only REP is used by the program below.
#define FE(i, a, b) for(int i = (a); i <= (b); ++i)
#define FD(i, b, a) for(int i = (b); i>= (a); --i)
#define REP(i, N) for(int i = 0; i < (N); ++i)
#define CLR(A,value) memset(A,value,sizeof(A))
#define CPY(a, b) memcpy(a, b, sizeof(a))
#define FC(it, c) for(__typeof((c).begin()) it = (c).begin(); it != (c).end(); it++)
//INPUT
// scanf shorthands for one/two/three ints and a C string. Each expands to the
// scanf call, so its return value (number of fields read) is available.
// Only RI is used by the program below.
#define RI(n) scanf("%d", &n)
#define RII(n, m) scanf("%d%d", &n, &m)
#define RIII(n, m, k) scanf("%d%d%d", &n, &m, &k)
#define RS(s) scanf("%s", s)
//OUTPUT
// printf shorthands: one int or one C string followed by a newline (unused below).
#define WI(n) printf("%d\n", n)
#define WS(s) printf("%s\n", s)

// Template constants; none of LL, INF, eps or MAXN is referenced by the program below.
typedef long long LL;
const int INF = 1000000007;
const double eps = 1e-10;
const int MAXN = 1010;

// goal[i]: minimum number of cards that must be counted toward suit i
// (four values, filled in by main). Shared by check() and dfs().
std::vector<int> goal;
// Memo tables indexed by the six-entry state vector `tmp` used by dfs():
//   tmp[0..3] - natural cards of each suit drawn so far (0..13);
//   tmp[4..5] - suit (1..4) that joker 1 / joker 2 was assigned to,
//               0 while that joker is still in the deck.
double dp[14][14][14][14][5][5];
bool vis[14][14][14][14][5][5];

// True once every suit's count (natural cards plus jokers assigned to it)
// reaches its target. Unlike the earlier version, `tmp` is not modified.
inline bool check(const std::vector<int> &tmp)
{
    for (int i = 0; i < 4; i++)
    {
        const int have = tmp[i] + (tmp[4] == i + 1) + (tmp[5] == i + 1);
        if (goal[i] > have) return false;
    }
    return true;
}

// Expected number of further draws needed to reach `goal` from state `tmp`.
// Each joker's suit is chosen at the moment it is drawn, picking the suit that
// minimises the expected remaining draws.
//
// tmp: state vector (layout described above the memo tables); modified during
//      recursion and restored before returning.
// Returns HUGE_VAL for states from which the goal can no longer be reached
// (deck exhausted, some suit still short). The joker-assignment minimum
// therefore never prefers such a state over a reachable one.
double dfs(std::vector<int> &tmp)
{
    if (check(tmp)) return 0.0;

    double &memo = dp[tmp[0]][tmp[1]][tmp[2]][tmp[3]][tmp[4]][tmp[5]];
    bool &seen = vis[tmp[0]][tmp[1]][tmp[2]][tmp[3]][tmp[4]][tmp[5]];
    if (seen) return memo;
    seen = true;

    // Cards drawn = natural cards + jokers already drawn.
    int drawn = (tmp[4] != 0) + (tmp[5] != 0);
    for (int i = 0; i < 4; i++) drawn += tmp[i];
    const int remaining = 54 - drawn;
    // Deck exhausted but goal unmet: previously this fell through and returned 1.
    if (remaining <= 0) return memo = HUGE_VAL;

    double ret = 0;

    // Drawing a natural card of suit i: 13 - tmp[i] of them are still in the deck.
    for (int i = 0; i < 4; i++)
    {
        if (tmp[i] < 13)
        {
            const double left = 13 - tmp[i];
            tmp[i]++;
            ret += left / remaining * dfs(tmp);
            tmp[i]--;
        }
    }

    // Drawing joker 1 (slot 4) or joker 2 (slot 5): assign it to the best suit.
    for (int slot = 4; slot <= 5; slot++)
    {
        if (tmp[slot]) continue;  // this joker was already drawn
        double best = HUGE_VAL;
        for (int suit = 1; suit <= 4; suit++)
        {
            tmp[slot] = suit;
            best = std::min(best, 1.0 / remaining * dfs(tmp));
            tmp[slot] = 0;
        }
        ret += best;
    }

    // +1 for the draw made from this state.
    return memo = ret + 1;
}

// Reads T test cases, each consisting of four per-suit target counts.
// For each case, prints the minimal expected number of draws, or -1.000 when
// the targets cannot be met from a 54-card deck (13 cards per suit + 2 jokers).
// Stops quietly if the input ends early or is malformed, instead of using
// uninitialized values.
int main ()
{
    int T;
    if (scanf("%d", &T) != 1) return 0;
    for (int ncase = 1; ncase <= T; ncase++)
    {
        int excess = 0, total = 0;
        goal.clear();
        for (int i = 0; i < 4; i++)
        {
            int x;
            if (scanf("%d", &x) != 1) return 0;
            goal.push_back(x);
            // A suit has only 13 natural cards; anything beyond must come from a joker.
            if (x > 13) excess += x - 13;
            total += x;
        }
        double ans;
        if (total > 54 || excess > 2)
        {
            ans = -1.000;
        }
        else
        {
            memset(vis, 0, sizeof(vis));
            // Nothing drawn yet; both jokers still in the deck.
            std::vector<int> tmp(6, 0);
            ans = dfs(tmp);
        }
        printf("Case %d: %.3lf\n", ncase, ans);
    }
    return 0;
}