首页 > 代码库 > tsinsen A1486. 树(王康宁) 边分治+字典树

tsinsen A1486. 树(王康宁) 边分治+字典树

不知为何,这个代码只能得95分

放一下傻逼代码。。。

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
int n,K;
int nn2=1,nn=1,nod;
#define N 1000000
#define ed(x) (x>>1)
#define ab(x) ((x)>0?(x):-(x))
int e2[N],ne2[N],v2[N];
int e[N],ne[N],v[N];
int im[N],val[N];
int num[N];
bool been[N];
void add2(int x,int y){
   ne2[++nn2]=e2[x],e2[x]=nn2,v2[nn2]=y;    
}
void add(int x,int y){
    ne[++nn]=e[x],e[x]=nn,v[nn]=y;
}
int siz[N];
bool vis[N];
int tot;
int wv,u,mi,rt1,rt2;
int bn[2];
int res=-1;
struct P{
    int x,y;
    P(int a=0,int b=0){
        x=a,y=b;
    }
    bool operator<(P a)const{
       return y<a.y;
    }
}b[2][N];
void dfs(int x,int y){
    siz[y]=1;
    for(int i=e[y];i;i=ne[i])if(!vis[i]&&i!=(x^1)){
            dfs(i,v[i]);
            siz[y]+=siz[v[i]];
    }
    u=ab(tot-siz[y]*2);
    if(x)if(wv==-1||u<mi){
        mi=u,wv=x;
    }
}
int yy;
void dfs2(int x,int fa,int xx,int kk){
    for(int i=e[x];i;i=ne[i])if(fa!=v[i]&&!vis[i]){
        dfs2(v[i],x,xx^val[v[i]],kk+im[v[i]]);
    }
    b[yy][++bn[yy]]=P(xx,kk);
}
int p[N][2],dd;
void cl(){
    for(int i=0;i<=dd;i++)p[i][1]=0,p[i][0]=0;
    dd=0;
}
void ins(int x){
    int now=0;
    for(int i=31;i>=1;i--){
        u=(x>>i-1)&1;
        if(!p[now][u])p[now][u]=++dd;
        now=p[now][u];
    }
}
int check(int x){
    if(!p[0][1]&&!p[0][0])return -1;
    int now=0;
    int ans=0;
    for(int i=31;i>=1;i--){
        u=(x>>i-1)&1;
        if(p[now][!u]){
            ans+=(1<<i-1);
            now=p[now][!u];
        }else now=p[now][u];
    }
    return ans;
}
void solv(int x,int wt){
    if(wt<2)return;
    tot=wt;
    wv=-1;
    dfs(0,x);
    vis[wv]=1;
    vis[wv^1]=1;
    int a1=v[wv],a2=v[wv^1];
    yy=0;
    bn[0]=bn[1]=0;
    dfs2(a1,0,val[a1],im[a1]);
    yy=1;
    dfs2(a2,0,val[a2],im[a2]);
    sort(b[1]+1,b[1]+bn[1]+1);
    sort(b[0]+1,b[0]+bn[0]+1);
    cl();
    int s1=bn[0];
    for(int i=1;i<=bn[1];i++){
        while(s1>0&&b[0][s1].y+b[1][i].y>=K){
            ins(b[0][s1].x);
            s1--;
        }
        res=max(res,check(b[1][i].x));
    }
    solv(a1,bn[0]);
    solv(a2,bn[1]);
}
void rebuild(int x){
    been[x]=1;
    if(num[x]>4){
        int n1=0;
        int u1=++nod;
        int u2=++nod;
        num[u1]=1;
        num[u2]=1;
        add(x,u1);
        add(u1,x);
        add(x,u2);
        add(u2,x);
        for(int i=e2[x];i;i=ne2[i])if(!been[v2[i]]){
            n1++;
            if(n1<=((num[x]-1)/2)){
                add2(u1,v2[i]);
                num[u1]++;
            }else{
                num[u2]++;
                add2(u2,v2[i]);
            }
        }
        rebuild(u1);
        rebuild(u2);
    }else{
       for(int i=e2[x];i;i=ne2[i])if(!been[v2[i]]){
            add(x,v2[i]);
            add(v2[i],x);
            rebuild(v2[i]);
       }
    }
}
int main(){
    scanf("%d%d",&n,&K);
    for(int i=1;i<=n;i++){
        scanf("%d",&im[i]);
    }
    for(int i=1;i<=n;i++){
        scanf("%d",&val[i]);
    }
    for(int i=1;i<n;i++){
        int a,b;
        scanf("%d%d",&a,&b);
        add2(a,b);
        add2(b,a);  
        num[a]++;
        num[b]++;
    }
    for(int i=1;i<=n;i++)if(im[i]>=K){
           res=max(res,val[i]);
        }
    nod=n;
    rebuild(1);

    solv(1,nod);
        
    cout<<res;
    
}