首页 > 代码库 > poj 3156 hash+记忆化搜索 期望dp

poj 3156 hash+记忆化搜索 期望dp

#include<iostream>#include<cstdio>#include<algorithm>#include<cstring>using namespace std;int n,m;#define N 32#define mod 10007LL#define mod2 10000007#define inf 0x3ffffffftypedef long long ll;typedef double dd;int f[N];int get(int x){    return f[x]==x?x:f[x]=get(f[x]);}int g[N];int cmp(int a,int b){    return a>b;}struct P{    int num;    int a[N];    P(){        num=0;        memset(a,0,sizeof(a));    }    void ss(){        sort(a+1,a+31,cmp);        for(int i=1;i<=30;i++)if(a[i])num=i;    }    void out(){        cout<<"num="<<num<<endl;       for(int i=1;i<=num;i++){            cout<<a[i]<<" ";       }       cout<<endl;    }    int geth(){        int b=1;        int ans=0;        for(int i=1;i<=num;i++)(ans+=b*a[i])%=mod,b=(b*30)%mod;        return ans;    }    int geth2(){        ll b=1;        ll ans=0;        for(int i=1;i<=num;i++)(ans+=b*a[i])%=mod2,b=(b*30)%mod2;            return ans;    }};int e[mod+2],nn,v[mod*10],ne[mod*10];dd w[mod*10];void add(int x,dd ww,int vv){    ne[++nn]=e[x],e[x]=nn,v[nn]=vv,w[nn]=ww;}dd cha(P x){    int y=x.geth();    int z=x.geth2();    for(int i=e[y];i;i=ne[i]){        if(v[i]==z)return w[i];    }    return inf;}void ins(P x,dd vv){    int y=x.geth();    int z=x.geth2();    add(y,vv,z);}dd ask(P s){    dd res=0;    dd a=cha(s);    if(a!=inf)return a;        P r=s;    if(s.num==1){        return 0;    }    int tot=0;    for(int i=1;i<=s.num;i++){        tot+=s.a[i]*(s.a[i]-1)/2;    }    for(int i=1;i<s.num;i++){        for(int j=i+1;j<=s.num;j++){            s.a[j]+=s.a[i];            s.a[i]=0;            s.ss();            res+=r.a[i]*r.a[j]*ask(s);            s=r;        }    }    res/=(n*(n-1)/2);    res+=1;    res=res*(n*(n-1)/2)/((n-1)*n/2-tot);    ins(s,res);    return res;}int main(){    while(scanf("%d%d",&n,&m)!=EOF){    nn=0;    memset(e,0,sizeof(e));    for(int i=1;i<=n;i++)f[i]=i;    for(int i=1;i<=m;i++){        int a,b;        scanf("%d%d",&a,&b);        f[get(a)]=get(b);    }    for(int i=1;i<=n;i++)g[get(i)]++;    P st;    for(int i=1;i<=n;i++)st.a[i]=g[i];    st.ss();    printf("%.10lf\n",ask(st));   }}