首页 > 代码库 > POJ 2486 Apple Tree

POJ 2486 Apple Tree

题目大意:

一棵树上每一个节点都有几个苹果。问在根节点出发,走不大于K步的情况下最多能取多少个苹果。


解题思路:

树形DP,对于每一个子树的根节点src,都有dp[src][i][0],表示从src走i步可以回到src最多能够得到多少苹果。dp[src][i][1]表示从src走i步没有回到src最多能够得到多少苹果。

状态有三种转移方式:
1、用i-j-2步走其他子树回到根节点再用j步走某一子树再回到根节点。

2、用i-j-2步走其他子树回到根节点再用j步走某一子树没有回到根节点。

3、用j步走某一子树再回到根节点在用i-j-1步走其他子树不回到根节点。


注意状态的初始化。


下面是代码:

#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <iostream>
#include <math.h>
#include <stdlib.h>
#include <vector>
#include <string>
#include <map>
#include <queue>
using namespace std;

int min(int a,int b)
{
    if(a>b)a=b;
    return a;
}
int max(int a,int b)
{
    if(a<b)a=b;
    return a;
}
struct node1
{
    int to,next;
}edge[105*2];
int head[105],cnt,n,k,applenum[105],u,v,dp[105][205][2];
bool vis[105];
void addedge(int u,int v)
{
    edge[cnt].to=v;
    edge[cnt].next=head[u];
    head[u]=cnt++;
    edge[cnt].to=u;
    edge[cnt].next=head[v];
    head[v]=cnt++;
}
bool chack(int src)
{
    int t=head[src];
    while(t!=-1)
    {
        if(!vis[edge[t].to])return true;
        t=edge[t].next;
    }
    return false;
}
void dfs(int src)
{
    if(vis[src])return;
    else vis[src]=true;
    if(chack(src))
    {
        int t=head[src];
        while(t!=-1)
        {
            if(vis[edge[t].to])
            {
                t=edge[t].next;
                continue;
            }
            dfs(edge[t].to);
            for(int i=k;i>0;i--)
            {
                for(int j=0;i-j>=0;j++)
                {
                    if(i-j-2>=0)
                    {
                        dp[src][i][0]=max(dp[src][i][0],dp[src][i-j-2][0]+dp[edge[t].to][j][0]);
                        dp[src][i][1]=max(dp[src][i][1],dp[src][i-j-2][1]+dp[edge[t].to][j][0]);
                    }
                    if(i-j-1>=0)
                    {
                        dp[src][i][1]=max(dp[src][i][1],dp[src][i-j-1][0]+dp[edge[t].to][j][1]);
                    }
                }
            }
            t=edge[t].next;
        }
    }
}
int main()
{
    while(scanf("%d%d",&n,&k)!=EOF)
    {
        cnt=0;
        memset(vis,false,sizeof(vis));
        memset(dp,0,sizeof(dp));
        for(int i=1;i<=n;i++)
        {
            head[i]=-1;
            scanf("%d",&applenum[i]);
        }
        for(int i=1;i<=n;i++)
        {
            for(int j=0;j<=k;j++)
            {
                dp[i][j][0]=applenum[i];
                dp[i][j][1]=applenum[i];
            }
        }
        for(int i=2;i<=n;i++)
        {
            scanf("%d%d",&u,&v);
            addedge(u,v);
        }
        dfs(1);
        printf("%d\n",max(dp[1][k][0],dp[1][k][1]));
    }
    return 0;
}