首页 > 代码库 > 点分治

点分治

每次找到删去后使剩余的最大连通块最小的点（即重心），统计经过重心的路径的贡献，然后删去重心，递归处理剩下的各个连通块。

#include <cstdio>
#include <iostream>
#include <algorithm>
using namespace std;

  // Centroid decomposition (点分治): counts unordered pairs of distinct
  // vertices whose weighted tree distance is at most k, accumulating into ans.
  //
  // Storage is a static 1-based adjacency list: nd[v] is the first edge out of
  // v (-1 when none); for edge e, nxt[e] is the next edge out of the same
  // vertex, des[e] its target and len[e] its weight.  Each undirected edge is
  // stored twice, so capacity is 40000 vertices / 80000 directed edges.
  //
  // The link and subtree-size arrays are deliberately NOT named `next` and
  // `size`: this file has `using namespace std;`, and those names collide with
  // std::next (C++11) and std::size (C++17), making every use ambiguous.
  int nxt[80001], des[80001], len[80001];
  int nd[40001], cnt;               // cnt: number of directed edges stored
  int bt[40001], dfstim;            // visit stamps: bt[v] == dfstim means v was
                                    // already visited by the current traversal
  int siz[40001], msiz[40001];      // subtree size / largest child subtree size
  int b[40001];                     // 1 while the vertex has not been removed
  int maxsiz, grav;                 // best balance found so far and its vertex
  int sta[40001], top;              // distances collected from one child subtree
  int fin[40001];                   // distances collected from a whole component
  int k;                            // distance threshold
  long long ans;                    // number of qualifying pairs found so far

  // Appends the directed edge x -> y with weight le to x's adjacency list.
  void addedge(int x, int y, int le) {
      ++cnt;
      nxt[cnt] = nd[x];
      des[cnt] = y;
      len[cnt] = le;
      nd[x] = cnt;
  }

  // Fills siz[] and msiz[] for the alive component containing po, rooted at
  // po.  The caller must increment dfstim first.
  void dfs1(int po) {
      bt[po] = dfstim;
      siz[po] = 1;
      msiz[po] = 0;
      for (int p = nd[po]; p != -1; p = nxt[p]) {
          const int v = des[p];
          if (b[v] && bt[v] < dfstim) {
              dfs1(v);
              msiz[po] = std::max(msiz[po], siz[v]);
              siz[po] += siz[v];
          }
      }
  }

  // Walks the same component again and records in grav the vertex whose
  // largest remaining piece after its removal (a child subtree, or the
  // tot - siz part above it) is smallest.  tot is the component size.
  void dfs2(int po, int tot) {
      bt[po] = dfstim;
      const int worst = std::max(msiz[po], tot - siz[po]);
      if (worst < maxsiz) {
          maxsiz = worst;
          grav = po;
      }
      for (int p = nd[po]; p != -1; p = nxt[p])
          if (b[des[p]] && bt[des[p]] < dfstim)
              dfs2(des[p], tot);
  }

  // Finds the centroid of the alive component containing po, stores it in
  // grav and also returns it (the function is declared int, so it must
  // return a value; falling off the end would be undefined behaviour).
  int findgrav(int po) {
      dfstim++;
      dfs1(po);
      maxsiz = 1000000000;
      dfstim++;
      dfs2(po, siz[po]);
      return grav;
  }

  // Pushes onto sta[1..top] the distance (from the vertex the caller measures
  // from) of every alive vertex reachable from po whose distance is <= k.
  // left is k minus po's distance.  A branch stops as soon as left < 0, which
  // assumes non-negative edge weights.
  void dfs3(int po, int left) {
      if (left < 0) return;
      sta[++top] = k - left;
      bt[po] = dfstim;
      for (int p = nd[po]; p != -1; p = nxt[p])
          if (b[des[p]] && bt[des[p]] < dfstim)
              dfs3(des[p], left - len[p]);
  }

  // Sorts d[1..m] ascending and returns the number of index pairs i < j with
  // d[i] + d[j] <= k, using a two-pointer sweep.  The sum is formed in
  // long long so that large k cannot overflow.
  long long countpairs(int *d, int m) {
      std::sort(d + 1, d + m + 1);
      long long res = 0;
      int lo = 1, hi = m;
      while (lo < hi) {
          if (static_cast<long long>(d[lo]) + d[hi] <= k) {
              res += hi - lo;   // d[lo] pairs with every d[lo+1..hi]
              ++lo;
          } else {
              --hi;
          }
      }
      return res;
  }

  // Counts every qualifying pair inside the alive component containing po
  // and adds it to ans, then recurses into the pieces left after removing
  // the component's centroid g.
  //
  // Every path either passes through g or lies wholly inside one piece.
  // Pairs through g are counted by inclusion-exclusion:
  //   pairs(all distances, including g itself at distance 0)
  //   - sum over children of pairs(that child's distances),
  // because two vertices of the same child subtree do not really meet at g.
  // Each level costs O(size log size), so a whole run is O(n log^2 n).
  // Re-sorting an accumulated array once per child would instead be
  // quadratic when g has high degree (e.g. a star).
  void work(int po) {
      findgrav(po);
      const int g = grav;   // grav is overwritten by the recursive calls below
      b[g] = 0;

      int all = 0;
      fin[++all] = 0;       // the centroid itself, at distance 0
      for (int p = nd[g]; p != -1; p = nxt[p]) {
          if (!b[des[p]]) continue;
          dfstim++;
          top = 0;
          dfs3(des[p], k - len[p]);
          ans -= countpairs(sta, top);
          for (int i = 1; i <= top; i++) fin[++all] = sta[i];
      }
      ans += countpairs(fin, all);

      for (int p = nd[g]; p != -1; p = nxt[p])
          if (b[des[p]])
              work(des[p]);
  }

  int main(){
      int n;
      scanf("%d",&n);
      for (int i=1;i<=n;i++) nd[i]=-1,b[i]=1;
      for (int i=1;i<n;i++){
        int t1,t2,t3;
      scanf("%d%d%d",&t1,&t2,&t3);
      addedge(t1,t2,t3);addedge(t2,t1,t3);
      }
      scanf("%d",&k);
      
      work(1);
      
      printf("%lld\n",ans);
  }

 

点分治