首页 > 代码库 > wust-1299-结点选择(树形DP)

wust-1299-结点选择(树形DP)

Problem Description

有一棵 n 个节点的树,树上每个节点都有一个正整数权值。如果一个点被选择了,那么在树上和它相邻的点都不能被选择。求选出的点的权值和最大是多少?

Input

第一行包含一个整数 n 。
接下来的一行包含 n 个正整数,第 i 个正整数代表点 i 的权值。
接下来一共 n-1 行,每行描述树上的一条边。

Output

输出一个整数,代表选出的点的权值和的最大值。

Sample Input

5
1 2 3 4 5
1 2
1 3
2 4
2 5

Sample Output

12

HINT

样例说明

选择3、4、5号点,权值和为 3+4+5 = 12 。


数据规模与约定

对于20%的数据, n <= 20。

对于50%的数据, n <= 1000。

对于100%的数据, n <= 100000。

权值均为不超过1000的正整数。

思路:对于每一个点i,要么选:val[i]+gson[i],其中gson[i]表示所有孙子节点的dp值之和;要么不选:son[i],其中son[i]表示所有儿子节点的dp值之和。题中给出的是无根树,只需要任选一个节点作为根节点(这里选1),就可以确定父子关系了。然后用数组模拟栈来建树,最后再从数组最后一个元素开始往前推(这样就相当于从树的最下面一层开始往上推),并且跟新父节点的son值,和祖父节点的gson值。

#include <cstdio>
using namespace std;

int val[100001],dp[100001],son[100001],gson[100001],first[100001],next[200002],to[200002],que[100001],far[100001];
bool vis[100001];

int main()
{
    int n,i,u,v,ans;

    while(~scanf("%d",&n))
    {
        for(i=1;i<=n;i++) first[i]=-1,vis[i]=son[i]=gson[i]=dp[i]=0;

        for(i=1;i<=n;i++) scanf("%d",&val[i]);

        for(i=0;i<n-1;i++)
        {
            scanf("%d%d",&u,&v);

            to[i*2]=v;
            next[i*2]=first[u];
            first[u]=i*2;

            to[i*2+1]=u;
            next[i*2+1]=first[v];
            first[v]=i*2+1;
        }

        int top=0,bottom=1;

        que[top]=1;
        vis[que[top]]=1;
        far[1]=-1;

        while(top<bottom)//建树
        {
            for(int e=first[que[top]];e!=-1;e=next[e])
            {
                if(!vis[to[e]])
                {
                    vis[to[e]]=1;

                    que[bottom++]=to[e];

                    far[to[e]]=que[top];//记录父亲节点
                }
            }

            top++;
        }

        ans=0;

        for(i=bottom-1;i>=0;i--)//从树的最下面一层开始往上推
        {
            dp[que[i]]=val[que[i]]+gson[que[i]]>son[que[i]]?val[que[i]]+gson[que[i]]:son[que[i]];

            ans=ans>dp[que[i]]?ans:dp[que[i]];

            if(far[que[i]]!=-1)
            {
                son[far[que[i]]]+=dp[que[i]];//更新父节点的son值

                if(far[far[que[i]]]!=-1)
                    gson[far[far[que[i]]]]+=dp[que[i]];//更新祖父节点的gson值
            }
        }

        printf("%d\n",ans);
    }
}