Home > Code Library > Given a tree, find the node with the minimum sum of distances to other nodes

Given a tree, find the node with the minimum sum of distances to other nodes

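Approach (O(n) when each node's neighbours can be enumerated in time proportional to its degree, e.g. with adjacency lists):

1. Traverse the tree once from an arbitrary root. For every node, record how many nodes lie in its subtree. Also record the root's sum of distances to all other nodes.
2. Traverse the tree a second time, deriving each node's sum from its parent's. Moving the root from a parent p to its child c, whose subtree holds s of the n nodes, brings those s nodes one step closer and pushes the other n - s nodes one step farther. Therefore S(c) = S(p) + n - 2s.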


#include <algorithm>
#include <iostream>
#include <stdexcept>
#include <utility>
#include <vector>
using namespace std;

// One vertex of the rooted tree that buildTree fills in.
class Node {
 public:
  // Counter of nodes below this one. It starts at the constructor argument
  // and buildTree adds (child's count + 1) for every child it attaches.
  int cnum;
  // Labels of this node's children, in the order they were attached.
  vector<int> child;

  Node(int num) : cnum{num}, child{} {}
};


// Recursively turns an undirected edge list into a rooted tree below `cur`.
//
// `cur` sits at distance `depth` from the overall root. Every edge that is
// still unmarked in `visited` (0 = unused) and touches `cur` is consumed in
// edge-list order. For each such edge:
//   * it is marked in `visited`;
//   * its other endpoint is appended to tree[cur].child;
//   * that endpoint's distance from the root (depth + 1) is added to `sum`;
//   * the subtree below the endpoint is built before the next edge is examined.
// For each attached child, tree[cur].cnum grows by 1 plus the value returned
// by the recursive call. With all counters starting at 0, it therefore ends
// up as the number of descendants of `cur`. That final value is returned.
//
// Every call walks the entire edge list, so building a tree with E edges costs
// O(E^2) time; the recursion depth equals the height of the tree.
// `visited` needs one entry per edge and `tree` must be indexable by every
// node label that appears in `edges`.
int buildTree(const vector<pair<int,int>>& edges,
              vector<int>& visited, vector<Node>& tree, int cur, int depth, int& sum) {
  const int edgeCount = edges.size();
  for (int e = 0; e < edgeCount; ++e) {
    if (visited[e] != 0)
      continue;

    const pair<int, int>& edge = edges[e];
    int other;
    if (edge.first == cur)
      other = edge.second;
    else if (edge.second == cur)
      other = edge.first;
    else
      continue;  // edge does not touch cur

    visited[e] = 1;
    sum += depth + 1;
    tree[cur].child.push_back(other);
    const int below = buildTree(edges, visited, tree, other, depth + 1, sum);
    tree[cur].cnum += below + 1;
  }
  return tree[cur].cnum;
}
void dfs(vector<Node>& tree, int root, int& minsum, int& minroot, int cursum) {

  int i, size = tree[root].child.size(), next, n = tree.size(), m;
  for (i = 0; i < size; ++i) {
    next = tree[root].child[i], m = tree[next].cnum;
    int sum = cursum - m + (n - m -2);
    if (sum < minsum) {
      minsum = sum;
      minroot = next;
    }
    dfs(tree, next, minsum, minroot, sum);
  }  
}
// Returns the label of the node whose sum of distances to all other nodes is
// smallest, for a tree given as an undirected edge list.
//
// Contract: the edges must form a single tree whose node labels are exactly
// 0..edges.size(). An empty edge list returns 0.
//
// Tie-breaking matches the original recursive implementation:
//   * nodes are examined in preorder starting from edges[0].first;
//   * each node's neighbours are taken in edge-list order;
//   * a node replaces the current best only when its sum is strictly smaller.
//
// Throws std::out_of_range if an endpoint label lies outside 0..edges.size().
// Throws std::invalid_argument if the edges do not connect all
// edges.size() + 1 nodes (a cycle, self-loop or duplicate edge).
//
// This does not reuse buildTree/dfs, because:
//   * buildTree rescans the whole edge list for every node (O(n^2)); adjacency
//     lists here make the whole computation O(n);
//   * distance sums reach about n^2/2 and overflow int beyond roughly 65k
//     nodes, so they are accumulated in long long;
//   * both passes are iterative, so a long path cannot exhaust the call stack.
int getNode(const vector<pair<int, int>>& edges) {
  if (edges.empty())
    return 0;

  // A tree with E edges has E + 1 nodes, labelled 0..E.
  const int n = static_cast<int>(edges.size()) + 1;

  // Neighbour lists in edge-list order. This order fixes the traversal order
  // and therefore which node wins a tie.
  vector<vector<int>> adj(n);
  for (const pair<int, int>& e : edges) {
    if (e.first < 0 || e.first >= n || e.second < 0 || e.second >= n)
      throw out_of_range("getNode: node label outside 0..edges.size()");
    adj[e.first].push_back(e.second);
    adj[e.second].push_back(e.first);
  }

  const int root = edges[0].first;

  // Pass 1: iterative preorder walk from root, recording each node's parent.
  // Neighbours are pushed in reverse so that they are popped in edge-list order.
  vector<int> parent(n, -1);
  vector<char> seen(n, 0);
  vector<int> order;
  order.reserve(n);
  vector<int> todo(1, root);
  seen[root] = 1;
  while (!todo.empty()) {
    const int u = todo.back();
    todo.pop_back();
    order.push_back(u);
    for (auto it = adj[u].rbegin(); it != adj[u].rend(); ++it) {
      if (!seen[*it]) {
        seen[*it] = 1;
        parent[*it] = u;
        todo.push_back(*it);
      }
    }
  }
  // With n - 1 edges, reaching all n nodes is equivalent to being a tree.
  if (static_cast<int>(order.size()) != n)
    throw invalid_argument("getNode: edges do not form a single tree");

  // Pass 2: subtree sizes, children before parents (reverse preorder).
  // Every node below v crosses the edge (parent[v], v) on its way to the
  // root. Hence the root's distance sum is the sum of all non-root subtree
  // sizes.
  vector<int> subtree(n, 1);
  long long rootSum = 0;
  for (int i = n - 1; i > 0; --i) {
    const int v = order[i];
    subtree[parent[v]] += subtree[v];
    rootSum += subtree[v];
  }

  // Pass 3: reroot in preorder. Moving the root from p to its child v brings
  // the subtree[v] nodes under v one step closer and the other n - subtree[v]
  // one step farther, so S(v) = S(p) + n - 2 * subtree[v].
  vector<long long> total(n, 0);
  total[root] = rootSum;
  long long minsum = rootSum;
  int minroot = root;
  for (int i = 1; i < n; ++i) {
    const int v = order[i];
    total[v] = total[parent[v]] + n - 2LL * subtree[v];
    if (total[v] < minsum) {
      minsum = total[v];
      minroot = v;
    }
  }
  return minroot;
}
// Demo: builds a 7-node sample tree, finds the node with the smallest sum of
// distances to all other nodes and prints its label.
// Node 1 is adjacent to 0, 3 and 5 and two steps from 2, 4 and 6. Its sum is
// therefore 9, the smallest in this tree, so the program prints 1.
int main() {
  const vector<pair<int, int>> edges = {
      {0, 1}, {0, 2}, {4, 3}, {1, 3}, {1, 5}, {6, 5},
  };
  const int res = getNode(edges);
  cout << res << '\n';
  return 0;
}