首页 > 代码库 > poj 4045 Power Station(初涉树形dp)

poj 4045 Power Station(初涉树形dp)

http://poj.org/problem?id=4045


大致题意:有n个村庄,求将发电站建在哪一个村庄使得花费最少。这是一个无向无环图。简化一下就是求一个节点使它到其他所有节点的距离和最小。


起初一直在向最短路上靠,但因为节点和边数太大,必定TLE。然后无比强大的啸神随便写了两个dfs就过掉了,简直膜拜。赛后搜了搜题解,发现这是道树形dp。sad,真的要好好刷dp了。

大体思路是将这个无向无环图看做一个树,我们就在这个树上进行动态规划。首先先随便拿一个节点看做根节点(假设节点1),计算出它到其他点的最小距离和,那么接下来当它的儿子做根节点的时候,根据父亲节点的值以及他们的关系就可以直接计算出儿子节点做根节点的距离和,具体是除该节点及其子树之外的所有节点的距离都加1,而该节点及其子节点距离都减1。这个距离是随便拿一个根节点dfs计算出来的。

设dp【】表示每个节点做根节点时它到子节点的距离之和,真正用到的是dp[1],son【】表示每个节点的孩子节点数目,包括自身。若已知父亲节点的花费cost[pre],那么当前节点u的花费cost[u] = cost[pre] + n-son[u] -son[u]。


总之,两次dfs,第一次为了求得dp[1]和每个节点的孩子数目son[],第二次是有父亲节点的花费求得当前节点的花费。

#include <stdio.h>
#include <iostream>
#include <map>
#include <stack>
#include <vector>
#include <math.h>
#include <string.h>
#include <queue>
#include <string>
#include <stdlib.h>
#include <algorithm>
#define LL long long
#define _LL __int64
#define eps 1e-8
#define PI acos(-1.0)
using namespace std;

const int maxn = 50000+10;

struct node
{
	int u,v,next;
}edge[maxn*2];

int cnt,head[maxn];
int n,I,R;

int son[maxn];//每个节点的儿子节点数目,这里包含该节点本身方便计算。
LL dp[maxn];//每个根节点的子节点到该节点的距离和
LL cost[maxn];//以每个节点为根的花费

void init()
{
	cnt = 0;
	memset(head,-1,sizeof(head));
}

void add(int u, int v)
{
	edge[cnt] = (struct node){u,v,head[u]};
	head[u] = cnt++;
}

void dfs(int u, int pre)
{
	dp[u] = 0;
	son[u] = 1;
	for(int i = head[u]; i != -1; i = edge[i].next)
	{
		int v = edge[i].v;
		if(v == pre) continue;
		dfs(v,u);

		son[u] += son[v]; //儿子数目
		dp[u] += dp[v] + son[v];//u到u的各个子节点的距离之和。
	}
}

void cal(int u, int pre)
{
    if(u != 1)
        cost[u] = cost[pre]+n-son[u]-son[u];
    for(int i = head[u]; i != -1; i = edge[i].next)
    {
        int v = edge[i].v;
        if(v == pre) continue;
        cal(v,u);
    }

}

int main()
{
	int test;
	scanf("%d",&test);
	while(test--)
	{
		scanf("%d %d %d",&n,&I,&R);
		int u,v;
		init();
		for(int i = 0; i < n-1; i++)
		{
			scanf("%d %d",&u,&v);
			add(u,v);
			add(v,u);
		}
		//假设1为根节点
		dfs(1,0);

        memset(cost, 0, sizeof(cost));
        cost[1] = dp[1];
        cal(1,0);

        LL Min = cost[1];
        for(LL i = 2; i <= n; i++)
            Min = min(Min,cost[i]);

        cout << Min*I*I*R << endl;

        bool flag = false;
        for(LL i = 1; i <= n; i++)
        {
            if(cost[i] == Min)
            {
                if(flag == false)
                {
                    flag = true;
                    cout << i;
                }
                else cout << " " << i;
            }
        }
        cout << endl << endl;
	}
	return 0;
}