首页 > 代码库 > Codeforces 771C

Codeforces 771C

我的树形dp果然是渣。。。

题意:给一棵树,共n(0<n<=15e4)个节点,可在树上进行跳跃,每次跳的最大距离为k(0<k<=5),定义f(s,t)为(dis(s,t)+k)/k,问Σf(s,t),s<t。

解题思路:

  显然是树形dp,问题在于怎么构建状态。

  最简单想到的就是,每到一个节点u,记录其子树中与其距离为d的的节点的数目,即(dis,cnt)对,则答案分两种情况,u到其子节点和以及子节点经过u到子节点,问题变得很简单,计算也不难——但问题在于极端情况下——比如树退化成链,时空复杂度都将高到无法忍受,于是卡在了这里……

  然后看了别人的题解,发现不需要构建(dis,cnt)对,而是构建(dis%k,cnt)对——这样复杂度枚举的最高复杂度也就只是k^2而不是dis^2,,,啊感觉自己宛若一个zz。

  定义,sz(u,i)表示与节点u的距离%k为 i 的节点数目,dp(u,i)表示从u到sz(u,i)中节点的跳数和

  状态转移如下:

  (1)sz(u,(i+1)%k)+=sz(v,i)

  (2)dp(u,(i+1)%k)+=dp(v,i)  0<i<k  

  (3)dp(u,1%k)+=dp(v,0)+sz(v,0)

  显然,与u的子节点v的距离%k大于0的,到v所需跳数与到u所需跳数是相同的(式子(2))。否则跳数需要+1,有sz(v,0)个点,故再加上sz(v,0)(式子(3))。

  接下来是统计答案,分两种情况,一个是u到其子节点,另一个是u到子节点到其它子节点(子节点间的最近公共祖先节点为u)。

  对于第一种,直接res+=Σdp[u][i]即可。

  对于第二种,将之前已经遍历过的子树节点都合并到dp[u][i]与sz[u][i]中,则与新的子树节点v合并时,枚举k1、k2,依次为已遍历过的子树节点与u的距离%k=k1,以及与v的子树中与u的距离%k为k2的节点(语文不好这段贼拗口……k^2的枚举就是这里了),分情况讨论:

  如果k1为0,则dp[u][k1]中则为恰好到达u的跳数(即每一跳距离都是k),则从sz[u][k1]中的节点到达 v的子树中的节点 所需跳数 恰好为二者相加之和,不需额外处理;如果k2为0,同理;当k1与k2都不为0时,考虑k1+k2<=k,则此时说明多计算了一跳,因此需要减去;当k1+k2>k时,恰好符合。

  综上,需要特判的也只有(k1&&k2&&k1+k2<=k)这种情况。

  代码如下:

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
typedef long long ll;
#define sqr(x) ((x)*(x))
const int N=2e5+10;
int head[N],nxt[N<<1],to[N<<1],cnt;
int n,k,a,b;
ll sz[N][6],dp[N][6],res;
void init(){
    memset(head,-1,sizeof(head));
    res=cnt=0;
    memset(sz,0,sizeof(sz));
    memset(dp,0,sizeof(dp));
}
void addEdge(int u,int v){
    nxt[cnt]=head[u];
    to[cnt]=v;
    head[u]=cnt++;
}
void dfs(int u,int pre){
    ll tsz[6],tdp[6];  //用以暂时保存从u到新的v子树节点的数据
    for(int e=head[u];~e;e=nxt[e]){
        int v=to[e];
        if(v==pre) continue;
        dfs(v,u);
        memset(tsz,0,sizeof(tsz));
        memset(tdp,0,sizeof(tdp));
        for(int i=0;i<k;i++)
            tsz[(i+1)%k]+=sz[v][i];
        for(int i=1;i<k;i++)
            tdp[(i+1)%k]+=dp[v][i];
        tdp[1%k]+=dp[v][0]+sz[v][0];
        for(int k1=0;k1<k;k1++){
            for(int k2=0;k2<k;k2++){
                res+=dp[u][k1]*tsz[k2]+sz[u][k1]*tdp[k2];
                if(k1&&k2&&k1+k2<=k) res-=sz[u][k1]*tsz[k2];
            }
        }
        //将v的子树节点情况合并到u下
        for(int i=0;i<k;i++)
            dp[u][i]+=tdp[i],sz[u][i]+=tsz[i];
    }
    for(int i=0;i<k;i++)
        res+=dp[u][i];
    sz[u][0]++;
}
int main(){
    //freopen("in.txt","r",stdin);
    while(~scanf("%d%d",&n,&k)){
        init();
        for(int i=1;i<n;i++){
            scanf("%d%d",&a,&b);
            addEdge(a,b);
            addEdge(b,a);
        }
        dfs(1,0);
        printf("%I64d\n",res);
    }
    return 0;
}

  参考题解:http://www.cnblogs.com/AOQNRMGYXLMV/p/6579771.html

Codeforces 771C