首页 > 代码库 > hiho12周树形dp

hiho12周树形dp

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>
using namespace std;

// Tree DP (hiho week 12). Input per test case: N M, then N node scores, then
// N-1 undirected edges. Output: the largest total score of a connected set of
// nodes that contains node 1 and has at most M nodes. (If all scores are
// non-negative and M <= N, this equals the best set of exactly M nodes.)
//
// All tables are sized per test case from N, so large N or M can no longer
// overrun fixed-size arrays, and no large table is cleared per test case.

// adj[v]: neighbours of node v (1-based ids); filled by add().
vector<vector<int>> adj;
// val[v]: score of node v.
vector<int> val;
// dp[v][j] for 1 <= j <= (budget passed to dfs for v): best score of a
// connected set that contains v, uses only nodes of v's subtree, and has at
// most j nodes.
vector<vector<int>> dp;

// Records the directed edge from -> to; call it in both directions for an
// undirected tree edge.
void add(int from, int to) {
    adj[from].push_back(to);
}

// Fills dp[x][1..m] for the subtree hanging below x (whose parent is `father`,
// -1 for the root) and returns dp[x][m]. Returns 0 and leaves dp[x] untouched
// when m <= 0.
int dfs(int x, int m, int father) {
    if (m <= 0) return 0;

    vector<int>& best = dp[x];
    // Choosing only x itself is valid for every budget j >= 1.
    best.assign(m + 1, val[x]);
    best[0] = 0;

    for (int child : adj[x]) {
        if (child == father) continue;
        // x always occupies one slot, so the child's part gets at most m - 1.
        dfs(child, m - 1, x);
        const vector<int>& sub = dp[child];
        // Group knapsack over this child: j descends so best[j - k] still
        // describes the sets built from earlier children only.
        for (int j = m; j >= 2; --j) {
            int candidate = best[j];
            for (int k = 1; k < j; ++k)
                candidate = max(candidate, sub[k] + best[j - k]);
            best[j] = candidate;
        }
    }
    return best[m];
}

int main() {
    int n, m;
    while (scanf("%d%d", &n, &m) == 2) {
        const int nodes = max(n, 0);
        adj.assign(nodes + 1, vector<int>());
        val.assign(nodes + 1, 0);
        dp.assign(nodes + 1, vector<int>());

        for (int i = 1; i <= nodes; i++)
            scanf("%d", &val[i]);
        for (int i = 0; i < nodes - 1; i++) {
            int a = 0, b = 0;
            scanf("%d%d", &a, &b);
            // Ignore malformed endpoints instead of indexing out of range.
            if (a < 1 || a > nodes || b < 1 || b > nodes) continue;
            add(a, b);
            add(b, a);
        }

        // A set inside the tree never has more than `nodes` nodes, so any
        // budget above that yields the same answer; clamping bounds dp rows.
        const int budget = min(m, nodes);
        printf("%d\n", nodes >= 1 ? dfs(1, budget, -1) : 0);
    }
    return 0;
}

 

hiho12周树形dp