首页 > 代码库 > 树的点分治 (poj 1741, 1655(树形dp))
树的点分治 (poj 1741, 1655(树形dp))
poj 1655:http://poj.org/problem?id=1655
题意: 给无根树, 找出以一节点为根, 使节点最多的树,节点最少。
题解:一道树形dp,先dfs 标记 所有节点的子树的节点数。 再dfs 找出以某节点为根的最大子树,节点最少。 复杂度(n)
/***Good Luck***/#define _CRT_SECURE_NO_WARNINGS#include <iostream>#include <cstdio>#include <cstdlib>#include <cstring>#include <string>#include <algorithm>#include <stack>#include <map>#include <queue>#include <vector>#include <set>#include <functional>#include <cmath>#define Zero(a) memset(a, 0, sizeof(a))#define Neg(a) memset(a, -1, sizeof(a))#define All(a) a.begin(), a.end()#define PB push_back#define inf 0x7fffffff#define inf2 0x7fffffffffffffff#define ll long longusing namespace std;const int maxn = 100000;int n, k, e, head[maxn];int mx[maxn], sum[maxn], ansn, ansb;struct Node { int next, v;}node[maxn];void input(int u, int v) { node[e].next = head[u]; node[e].v = v; head[u] = e++;}int dfssize(int b, int fa) { sum[b] = 1; mx[b] = 0; int tmpmx; for (int i = head[b]; ~i; i = node[i].next) { int v = node[i].v; if (fa != v) { tmpmx = dfssize(v, b); sum[b] +=tmpmx; if (tmpmx > mx[b]) mx[b] = tmpmx; } } return sum[b];}void solve(int b, int fa) { int tmpmx; tmpmx = max(mx[b], n - sum[b]); if (tmpmx <= ansb) { if (tmpmx < ansb) { ansn = b; ansb = tmpmx; } else if (b < ansn) { ansn = b; ansb = tmpmx; } } for (int i = head[b]; ~i; i = node[i].next) { int v = node[i].v; if (fa != v) { solve(v, b); } }}int main() { int u, v; int T; scanf("%d", &T); while (T-- ) { scanf("%d", &n); e = 0; Neg(head); for (int i = 0; i < n - 1; ++i) { scanf("%d%d", &u, &v); input(u, v); input(v, u); } dfssize(1, 0); ansb = inf; solve(1, 0); printf("%d %d\n", ansn, ansb); } return 0;}
poj 1741:http://poj.org/problem?id=1741
题意:给一值k,在带权无向图G中, 找出两节点相距不大于k的数。
qzc论文的第一题(膜拜q神 orz),根据论文写的 代码, 先写了一题树形dp(1655),再开始写这的,搞了一晚上具体的还是看论文吧。
找根(n), 计算(logn), 一共执行 logn次 总复杂度(n*logn*logn)
1 /***Good Luck***/ 2 #define _CRT_SECURE_NO_WARNINGS 3 #include <iostream> 4 #include <cstdio> 5 #include <cstdlib> 6 #include <cstring> 7 #include <string> 8 #include <algorithm> 9 #include <stack> 10 #include <map> 11 #include <queue> 12 #include <vector> 13 #include <set> 14 #include <functional> 15 #include <cmath> 16 17 #define Zero(a) memset(a, 0, sizeof(a)) 18 #define Neg(a) memset(a, -1, sizeof(a)) 19 #define All(a) a.begin(), a.end() 20 #define PB push_back 21 #define inf 0x3f3f3f3f 22 #define inf2 0x7fffffffffffffff 23 #define ll long long 24 using namespace std; 25 const int maxn = 20000; 26 int head[maxn], n, k, e; 27 int ans, sum[maxn], mx[maxn]; 28 bool vis[maxn]; 29 int dis[maxn], a[maxn], an; 30 struct Node { 31 int w; 32 int v, next; 33 }edge[maxn]; 34 35 void init() { 36 e = 0; 37 ans = 0; 38 Neg(head); 39 Zero(vis); 40 } 41 42 void add(int u, int v, int w) { //邻接表储存 43 edge[e].v = v; 44 edge[e].w = w; 45 edge[e].next = head[u]; 46 head[u] = e++; 47 } 48 49 int dfssize(int u, int fa) { //标记子树的节点数 50 sum[u] = 1; 51 mx[u] = 0; 52 int tmpmx; 53 for (int i = head[u]; ~i; i = edge[i].next) { 54 int v = edge[i].v; 55 if (v != fa && !vis[v]) { 56 tmpmx = dfssize(v, u); 57 sum[u] += tmpmx; 58 if (tmpmx > mx[u]) mx[u] = tmpmx; 59 } 60 } 61 return sum[u]; 62 } 63 64 int ansn, mxshu; 65 void find_root(int u, int fa, int nn) { // 找出符合条件的根。 66 int tmpmx = max(mx[u], nn - sum[u]); 67 if (tmpmx < mxshu) { 68 ansn = u; 69 mxshu = tmpmx; 70 } 71 for (int i = head[u]; ~i; i = edge[i].next) { 72 int v = edge[i].v; 73 if (v != fa && !vis[v]) { 74 find_root(v, u, nn); 75 } 76 } 77 } 78 79 void dfsdis(int u, int fa) { 80 a[an++] = dis[u]; 81 for (int i = head[u]; ~i; i = edge[i].next) { 82 int v = edge[i].v; 83 if (fa != v && !vis[v]) { 84 dis[v] = dis[u] + edge[i].w; 85 dfsdis(v, u); 86 } 87 } 88 } 89 90 int cal(int u, int fa, int beg) { // 这个方法太神奇了 复杂度只有 (logn) 91 an = 0; 92 int ret = 0; 93 dis[u] = beg; 94 dfsdis(u, fa); 95 sort(a, a + an); 96 int l = 0, r = an - 1; 97 while (l < r) { 98 if (a[r] + a[l] <= k ) 99 ret += r - l++;100 else101 r--;102 }103 return ret;104 }105 106 void solve(int u) {107 dfssize(u, 0);108 mxshu = inf;109 find_root(u, 0, sum[u]);110 vis[ansn] = true;111 ans += cal(ansn, 0, 0);112 for (int i = head[ansn]; ~i; i = edge[i].next) {113 int v = edge[i].v;114 if (!vis[v]) {115 ans -= cal(v, ansn, edge[i].w);116 solve(v);117 }118 }119 }120 int main() {121 //freopen("data.out", "w", stdout);122 //freopen("data.in", "r", stdin);123 int u, v, w;124 while (scanf("%d%d", &n, &k), n&&k) {125 init();126 for (int i = 0; i < n - 1; ++i) {127 scanf("%d%d%d", &u, &v, &w);128 add(u, v, w);129 add(v, u, w);130 }131 solve(1);132 printf("%d\n", ans);133 }134 return 0;135 }
声明:以上内容来自用户投稿及互联网公开渠道收集整理发布,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任,若内容有误或涉及侵权可进行投诉: 投诉/举报 工作人员会在5个工作日内联系你,一经查实,本站将立刻删除涉嫌侵权内容。