首页 > 代码库 > 并查集

并查集

int const MAX_N=100000;
int par[MAX_N];//父亲
int rank[MAX_N];//树的高度
//初始化n个元素
void init(int n)
{
    for(int i=0;i<n;i++)
    {
        par[i]=i;
        rank[i]=0;
    }
} 
//查询树的根
int find(int x)
{
    if(par[x]==x)
    {
        return x;
    }
    else
    {
        return par[x]=find(par[x]);
    }
} 
//合并x和y所属的集合
void unite(int x,int y)
{
    x=find(x);
    y=find(y);
    if(x==y) return ;
    if(rank[x]<rank[y])
    {
        par[x]=y;
    } 
    else
    {
        par[y]=x;
        if(rank[x]==rank[y]) rank[x]++;
    }
} 
//判断x和y是否属于同一个集合
bool same(int x,int y)
{
    return find(x)==find(y);
} 

具体问题:

//#define LOCAL
#include<cstdio>
#include<algorithm>
int const MAX_N=100000;
int const MAX_K=10000;
int par[MAX_N];//父亲
int rank[MAX_N];//树的高度
int N,K,T[MAX_K],X[MAX_K],Y[MAX_K];
//初始化n个元素
void init(int n)
{
    for(int i=0;i<n;i++)
    {
        par[i]=i;
        rank[i]=0;
    }
} 
//查询树的根
int find(int x)
{
    if(par[x]==x)
    {
        return x;
    }
    else
    {
        return par[x]=find(par[x]);
    }
} 
//合并x和y所属的集合
void unite(int x,int y)
{
    x=find(x);
    y=find(y);
    if(x==y) return ;
    if(rank[x]<rank[y])
    {
        par[x]=y;
    } 
    else
    {
        par[y]=x;
        if(rank[x]==rank[y]) rank[x]++;
    }
} 
//判断x和y是否属于同一个集合
bool same(int x,int y)
{
    return find(x)==find(y);
} 
void solve()
{
    init(N*3);
    
    int ans=0;
    for(int i=0;i<K;i++)
    {
        int t=T[i];
        int x=X[i]-1,y=Y[i]-1;
        
    //不正确的编号
        if(x<0||N<=x||y<0||N<=y)
        {
            ans++;
            continue;
        } 
        
        if(t==1)
        {
            if(same(x,y+N)||same(x,y+2*N))
            {
                ans++; 
            }
            else
            {
                unite(x,y);
                unite(x+N,y+N);
                unite(x+N*2,y+N*2);
            }
        }
        else
        {//x吃y的信息
            if(same(x,y)||same(x,y+2*N))
            {
                ans++;
            } 
            else
            {
                unite(x,y+N);
                unite(x+N,y+2*N);
                unite(x+2*N,y);
            }
        }
    }
    printf("%d\n",ans);
}
int main()
{
#ifdef LOCAL
    freopen("207.in","r",stdin);
    freopen("207.out","w",stdout);
#endif
    scanf("%d%d",&N,&K);
    for(int i=0;i<K;i++)
    {
        scanf("%d%d%d",&T[i],&X[i],&Y[i]);
    }
    solve();
    return 0;
}