首页 > 代码库 > 51nod_1677:treecnt

51nod_1677:treecnt

题目是求一棵n节点树中对于C(n,k)颗子树,每棵子树为在n个节点中选不同的k个节点作为树的边界点,这样的所有子树共包含多少条边。

问题可以转化一下,对每一条边,不同的子树中可能包含可能不包含这条边,显然,只有子树那k个节点在该边的两侧均有分布时该边才被包含在子树中。所有边的被包含次数的和,即为answer。对于一条边的被包含次数,设该边两侧分别有a,b个节点,那么,该边被包含的次数为C(a+b,k)-C(a,k)-C(b,k)(也可以借助母函数函数求C(a,i)*C(b,k-i),i从1到min{a,b,k-1},结果一样)。

//dfs写的太搓了,调了半天才好。。。

题目链接: https://www.51nod.com/contest/problem.html#!problemId=1677

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 
 4 typedef long long LL;
 5 const LL mod=1e9+7;
 6 const LL M=1e5+3;
 7 
 8 LL fac[100005];            //阶乘
 9 LL inv_of_fac[100005];        //阶乘的逆元
10 
11 LL qpow(LL x,LL n)
12 {
13     LL ret=1;
14     for(; n; n>>=1)
15     {
16         if(n&1) ret=ret*x%mod;
17         x=x*x%mod;
18     }
19     return ret;
20 }
21 void init()
22 {
23     fac[1]=1;
24     for(int i=2; i<=M; i++)
25         fac[i]=fac[i-1]*i%mod;
26     inv_of_fac[M]=qpow(fac[M],mod-2);
27     for(int i=M-1; i>=0; i--)
28         inv_of_fac[i]=inv_of_fac[i+1]*(i+1)%mod;
29 }
30 LL C(LL a,LL b)
31 {
32     if(b>a) return 0;
33     if(b==0) return 1;
34     return fac[a]*inv_of_fac[b]%mod*inv_of_fac[a-b]%mod;
35 }
36 /////////////////////////////////////////////////////////////
37 vector<int> adj[M];
38 int vis[M];
39 LL n,k,ans,du[M],hh;
40 void init1()
41 {
42     ans=0;
43     memset(vis,0,sizeof(vis));
44     memset(du,0,sizeof(du));
45     du[1]=n;
46     hh=C(n,k);
47     for(int i=1; i<=n; i++)
48         adj[i].clear();
49 }
50 LL dfs(int s)
51 {
52     if(adj[s].size()==1&&s!=1) return du[s]=1;
53     if(du[s]&&s!=1)    return du[s];
54     vis[s]=1;
55     LL ret,cnt=0;
56     for(int i=0; i<adj[s].size(); i++)
57     {
58         if(!vis[adj[s][i]])
59         {
60 //            printf("%d -> %d\n",s,adj[s][i]);
61             cnt+=dfs(adj[s][i]);
62             ans=(ans+hh-C(dfs(adj[s][i]),k)-C(n-dfs(adj[s][i]),k))%mod;
63         }
64     }
65     return du[s]=cnt+1;
66 }
67 
68 int main()
69 {
70     init();
71     while(~scanf("%lld%lld",&n,&k))
72     {
73         init1();
74         for(int i=1; i<n; i++)
75         {
76             LL u,v;
77             scanf("%d%d",&u,&v);
78             adj[u].push_back(v);
79             adj[v].push_back(u);
80         }
81         dfs(1);
82 //        for(int i=1; i<=n; i++)
83 //            printf("%d:%lld=========\n",i,du[i]);
84 //        for(int i=1; i<=n; i++)
85 //        {
86 //            printf("i=%d:\n",i);
87 //            for(int j=0; j<adj[i].size(); j++)
88 //                printf("%d ",adj[i][j]);
89 //            puts("");
90 //        }
91         printf("%lld\n",(ans+mod)%mod);
92     }
93 }

 

51nod_1677:treecnt