首页 > 代码库 > bzoj 2815 灾难

bzoj 2815 灾难

  首先假设我们定义x灭绝后y会灭绝,那么离y最近的x就为y的父亲节点,那么如果我们可以求出每个节点的父亲节点,我们就得到了一棵树,然后每个节点的灾难值就是子树的大小-1。

  我们将出度数为0的节点的父亲节点定义为0,那么我们可以发现,某个点的父亲节点就是他所有儿子的父亲节点的lca。

  备注:lca写错了,查了半天。

//By BLADEVIL
#include <cstdio>
#include <cstring>
#include <algorithm>
#define maxn 100010
#define maxm 2000100

using namespace std;

int n;
int pre[maxm][3],last[maxn][3],other[maxm][3],cnt[maxn],l[3];
int que[maxn],dep[maxn],size[maxn];
int jump[maxn][20];

int lca(int x,int y) {
    //if ((!x)||(!y)) return x+y;
    //printf("fuck %d %d\n",x,y);
    if (dep[x]>dep[y]) swap(x,y);
    int det(dep[y]-dep[x]);
    for (int j=0;j<=18;j++) if (det&(1<<j)) y=jump[y][j];
    //printf("%d\n",y);
    if (x==y) return x;
    for (int j=18;j>=0;j--) if (jump[x][j]!=jump[y][j]) x=jump[x][j],y=jump[y][j];
    return jump[x][0];
}

void connect(int x,int y,int cur) {
    if (((!y)||(!x))&&(cur!=2)) return ;
    pre[++l[cur]][cur]=last[x][cur];
    last[x][cur]=l[cur];
    other[l[cur]][cur]=y;
    if (!cur) cnt[x]++;
    //if (cur==2) printf("|%d %d\n",x,y);
}

void dfs(int x) {
    size[x]=1;
    for (int p=last[x][2];p;p=pre[p][2]) {
        dfs(other[p][2]);
        size[x]+=size[other[p][2]];
    }
}


int main() {
    scanf("%d",&n);
    for (int i=1;i<=n;i++) {
        int x(1); 
        while (x) scanf("%d",&x),connect(i,x,0),connect(x,i,1);
    }
    int h(0),t(0);
    //for (int i=1;i<=n;i++) printf("%d ",cnt[i]); printf("\n");
    for (int i=1;i<=n;i++) if (!cnt[i]) que[++t]=i;
    while (h<t) {
        int cur(que[++h]);
        for (int p=last[cur][1];p;p=pre[p][1]) {
            cnt[other[p][1]]--;
            if (!cnt[other[p][1]]) {
                que[++t]=other[p][1];
            }
        }
    }
    //for (int i=1;i<=n;i++) printf("%d %d %d\n",i,que[i],dep[i]); printf("\n");
    memset(dep,0,sizeof dep);
    dep[0]=1;
    for (int i=1;i<=n;i++) {
        int cur(que[i]);
        jump[cur][0]=other[last[cur][0]][0];
        for (int p=pre[last[cur][0]][0];p;p=pre[p][0]) jump[cur][0]=lca(jump[cur][0],other[p][0]);
        //printf("|%d %d\n",cur,jump[cur][0]);
        connect(jump[cur][0],cur,2); dep[cur]=dep[jump[cur][0]]+1;
        for (int j=1;j<=18;j++) jump[cur][j]=jump[jump[cur][j-1]][j-1];
    }
    //printf("%d\n",lca(0,5));
    //for (int i=1;i<=n;i++) printf("%d ",jump[i][0]); printf("\n");
    dfs(0);
    for (int i=1;i<=n;i++) printf("%d\n",size[i]-1);    
    return 0;
}