首页 > 代码库 > bzoj 4033 树上染色 - 树形动态规划

bzoj 4033 树上染色 - 树形动态规划

  有一棵点数为N的树,树边有边权。给你一个在0~N之内的正整数K,你要在这棵树中选择K个点,将其染成黑
色,并将其他的N-K个点染成白色。将所有点染色后,你会获得黑点两两之间的距离加上白点两两之间距离的和的
收益。问收益最大值是多少。

Input

第一行两个整数N,K。
接下来N-1行每行三个正整数fr,to,dis,表示该树中存在一条长度为dis的边(fr,to)。
输入保证所有点之间是联通的。
N<=2000,0<=K<=N

Output

输出一个正整数,表示收益的最大值。

Sample Input

5 21 2 31 5 12 3 12 4 2

Sample Output

17【样例解释】将点1,2染黑就能获得最大收益。

  动态规划的第一步——设计状态,f[i][j]表示以i节点为根的子树中染了j个黑点的"收益"。

  不过这样没有黑点的位置,这么多个点,总不可能用N进制来表示点的位置。所以只能换个思路。

  对于当前考虑的这棵子树,我知道染了j个节点,那么我知道在这棵子树内的白点数和子树外的白点数和黑点数。因此我可以计算出节点i到它的父节点的那条边的对答案的贡献,对于子节点转移到父节点就是一个用dp合并的过程,因此解决了状态转移的问题,时间复杂度为O(nk)。

  注意dp时不合法的状态一定不能转移(看代码吧,或者自己想想也可以,状态转移前有个if)

Code

  1 /**  2  * bzoj  3  * Problem#4033  4  * Accepted  5  * Time:630ms  6  * Memory:17092k  7  */  8 #include<iostream>  9 #include<fstream>  10 #include<sstream> 11 #include<algorithm> 12 #include<cstdio> 13 #include<cstring> 14 #include<cstdlib> 15 #include<cctype> 16 #include<cmath> 17 #include<ctime> 18 #include<map> 19 #include<stack> 20 #include<set> 21 #include<queue> 22 #include<vector> 23 #ifndef WIN32 24 #define AUTO "%lld" 25 #else 26 #define AUTO "%I64d" 27 #endif 28 using namespace std; 29 typedef bool boolean; 30 #define inf 0xfffffff 31 #define smin(a, b) (a) = min((a), (b)) 32 #define smax(a, b) (a) = max((a), (b)) 33 template<typename T> 34 inline boolean readInteger(T& u) { 35     char x; 36     int aFlag = 1; 37     while(!isdigit((x = getchar())) && x != - && x != -1); 38     if(x == -1)    { 39         ungetc(x, stdin); 40         return false; 41     } 42     if(x == -) { 43         aFlag = -1; 44         x = getchar(); 45     } 46     for(u = x - 0; isdigit((x = getchar())); u = u * 10 + x - 0); 47     u *= aFlag; 48     ungetc(x, stdin); 49     return true; 50 } 51  52 ///map template starts 53 typedef class Edge{ 54     public: 55         int end; 56         int next; 57         int w; 58         Edge(const int end = 0, const int next = 0, const int w = 0):end(end), next(next), w(w){} 59 }Edge; 60  61 typedef class MapManager{ 62     public: 63         int ce; 64         int *h; 65         Edge *edge; 66         MapManager(){} 67         MapManager(int points, int limit):ce(0){ 68             h = new int[(const int)(points + 1)]; 69             edge = new Edge[(const int)(limit + 1)]; 70             memset(h, 0, sizeof(int) * (points + 1)); 71         } 72         inline void addEdge(int from, int end, int w){ 73             edge[++ce] = Edge(end, h[from], w); 74             h[from] = ce; 75         } 76         inline void addDoubleEdge(int from, int end, int w){ 77             addEdge(from, end, w); 78             addEdge(end, from, w); 79         } 80         Edge& operator [] (int pos) { 81             return edge[pos]; 82         } 83 }MapManager; 84 #define m_begin(g, i) (g).h[(i)] 85 ///map template ends 86  87 template<typename T>class Matrix{ 88     public: 89         T *p; 90         int lines; 91         int rows; 92         Matrix():p(NULL){    } 93         Matrix(int rows, int lines):lines(lines), rows(rows){ 94             p = new T[(lines * rows)]; 95         } 96         T* operator [](int pos){ 97             return (p + pos * lines); 98         } 99 };100 #define matset(m, i, s) memset((m).p, (i), (s) * (m).lines * (m).rows)101 102 int n, k;103 MapManager g;104 Matrix<long long> f;105 int* size;106 107 inline void init() {108     readInteger(n);109     readInteger(k);110     g = MapManager(n, 2 * n);111     f = Matrix<long long>(n + 1, k + 1);112     size = new int[(const int)(n + 1)];113     matset(f, 0, sizeof(long long));114     for(int i = 1, a, b, c; i < n; i++) {115         readInteger(a);116         readInteger(b);117         readInteger(c);118         g.addDoubleEdge(a, b, c);119     }120 }121 122 void treedp(int node, int fa, int len) {123     size[node] = 1;124     for(int i = m_begin(g, node); i != 0; i = g[i].next) {125         int& e = g[i].end;126         if(e == fa)     continue;127         treedp(e, node, g[i].w);128         size[node] += size[e];129         for(int j = min(size[node], k); j >= 0; j--) {130             for(int s = 0; s <= size[e] && s <= j; s++) {131                 if(j - s <= size[node] - size[e])132                     smax(f[node][j], f[node][j - s] + f[e][s]);133             }134         }135     }136     for(int i = 0; i <= min(size[node], k); i++)137             f[node][i] += (i * 1LL * (k - i) + (size[node] - i) * 1LL * (n - k - size[node] + i)) * len;138 }139 140 inline void solve() {141     treedp(1, 0, 0);142     printf(AUTO"\n", f[1][k]);143 }144 145 int main() {146     init();147     solve();148     return 0;149 }

bzoj 4033 树上染色 - 树形动态规划