首页 > 代码库 > HDU 4879 ZCC loves march(并查集+set)

HDU 4879 ZCC loves march(并查集+set)

题意:一个最大10^18*10^18的矩阵,给你最多十万个士兵的位置,分别分布在矩阵里,可能会位置重复,然后有2种操作,一种是把第i个士兵向上下左右移动,另一种是把第i个士兵与他横坐标纵坐标相同的士兵全部移到这个点上,然后要计算花费。

这道题我想了好几天。在看了标程得到一些提示后总算写出来了。加了读入优化后快了100ms左右达到546ms。

做法:开2个set分别维护X相同的和Y相同的,但是会有相同位置点的坐标,该怎么办?用并查集去维护相同位置的点,读入的时候就可能会有位置相同的点,所以读的时候就要用,但是又来一个问题,因为第一步操作要移动点,去移动并查集的不是根的节点很容易,直接就可以移动,但是如果要移动那个并查集的根呢?这个问题我百思不得其解。。看了标程中开了一个new数组,然后貌似是往后加。。顿时想通了,原来的点不删除,如果要移动,直接去新加一个点,然后把新点的坐标保存在new里面,同时还要开一个nct数组去保存并查集中有几个节点。如果你需要移动那个并查集的根,新加一个节点后,那个原来的点就没什么意义了,但是在第二个操作合并2个集合的时候这个根还是有用的。还有一点如果你在移动的操作下如果要移动这个节点的根的nct值是1,那么就说明这个集合只有一个有用的点了,虽然可能还有很多其他点。然后要把这个根从set中删除。注意在这里不能简单的判断该节点的父亲是否是自己,因为可能你就是一个并查集的根,你不能确定下面是否还有节点,所以只能从nct数组保存的节点个数来判断。

总体说实现起来真的挺困难。。

AC代码(加了读入优化):

#include<cstdio>
#include<ctype.h>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<vector>
#include<stack>
#include<cmath>
#include<queue>
#include<set>
#include<ctime>
using namespace std;
#define ll __int64
#define MOD 1000000007
const ll Max = 1e18+1;

inline void scan_d(ll &ret) {
    char c; ret=0;
    while((c=getchar())<'0'||c>'9');
    while(c>='0'&&c<='9') ret=ret*10+(c-'0'),c=getchar();
}

struct node
{
    ll x;
    ll y;
    int pp;
    bool operator!=(const node &n1) const
    {
        if(n1.x == x&&n1.y == y) return 0;
        return 1;
    }
};

struct cmpx
{
    bool operator()(const node &n1, const node &n2) const
    {
        if(n1.x == n2.x) return n1.y<n2.y;
        return n1.x<n2.x;
    }
};

struct cmpy
{
    bool operator()(const node &n1, const node &n2) const
    {
        if(n1.y == n2.y) return n1.x<n2.x;
        return n1.y<n2.y;
    }
};

set<node,cmpx>s1;
set<node,cmpy>s2;
int fa[200010];
int nct[200010],nnew[100005];
node po[200010];
int findit(int x)
{
    return fa[x]!=x ? fa[x] = findit(fa[x]):x;
}

void mergeit(int x, int y)
{
    x = findit(x);
    y = findit(y);
    nct[y] += nct[x];
    fa[x] = y;
}

node make(ll x, ll y, int pp)
{
    node t;t.x = x;t.y=y;t.pp=pp;return t;
}
int main()
{
    //freopen("input.txt","r",stdin);
//    freopen("o.txt","w",stdout);
    ll m;
    int n,i,j,t;
    while(~scanf("%d%I64d",&n,&m))
    {
        s1.clear();
        s2.clear();
        s1.insert(make(-1,0,0));
        s1.insert(make(Max,0,0));
        s2.insert(make(0,-1,0));
        s2.insert(make(0,Max,0));
        set<node,cmpx>::iterator itx1,itx2,tmp1;
        set<node,cmpy>::iterator ity1,ity2,tmp2;
        for(i = 1; i <= n; i++)
        {
            fa[i] = i;
            scan_d(po[i].x);scan_d(po[i].y);
//            scanf("%I64d%I64d",&po[i].x,&po[i].y);
            nnew[i] = i;
            po[i].pp = i;
            itx1 = s1.find(po[i]);
            if(itx1 != s1.end())
            {
                nct[i] = 0;
                mergeit(i,itx1->pp);
                nct[itx1->pp]++;
            }
            else
            {
                nct[i] = 1;
                s1.insert(po[i]);
                s2.insert(po[i]);
            }
        }
        char temp;
        scanf("%d\n",&t);
        ll ans = 0;
        while(t--)
        {
            scanf("%c",&temp);
            ll a,b;
            if(temp == 'Q')
            {
                scan_d(a);
//                scanf("%I64d",&a);
                a ^= ans;
                a = nnew[a];
                ans = 0;
                int x = findit(a);
                tmp1 = itx1 = itx2 = s1.lower_bound(po[x]);
                for(;itx1->x == po[x].x;itx1--);
                itx1++;
                for(;itx2->x == po[x].x;itx2++);
                while(itx1!=itx2)
                {
                    if(tmp1!=itx1)
                    {
                        int pos = itx1->pp;
                        ans += ((ll)nct[pos]*(((itx1->y-po[x].y)%MOD*((itx1->y-po[x].y)%MOD))%MOD))%MOD;
                        ans %= MOD;
                        mergeit(pos,x);
                        s2.erase(make(itx1->x,itx1->y,itx1->pp));
                        s1.erase(itx1++);
                    }
                    else itx1++;
                }
                tmp2 = ity1 = ity2 = s2.lower_bound(po[x]);
                for(;ity1->y == po[x].y;ity1--);
                ity1++;
                for(;ity2->y == po[x].y;ity2++);
                while(ity1!=ity2)
                {
                    if(tmp2!=ity1)
                    {
                        int pos = ity1->pp;
                        ans += ((ll)nct[pos]%MOD*(((ity1->x-po[x].x)%MOD*((ity1->x-po[x].x)%MOD))%MOD))%MOD;
                        ans %= MOD;
                        mergeit(pos,x);
                        s1.erase(make(ity1->x,ity1->y,ity1->pp));
                        s2.erase(ity1++);
                    }
                    else ity1++;
                }
                printf("%I64d\n",ans);
            }
            else
            {
                scan_d(a);scan_d(b);
//                scanf("%I64d%I64d",&a,&b);
                a ^= ans;
                int newa = nnew[a];
                if(temp == 'L')
                {
                    int k = findit(newa);
                    if(nct[k] == 1)
                    {
                        s1.erase(po[k]);
                        s2.erase(po[k]);
                    }
                    nct[k]--;
                    node t = make(po[k].x,po[k].y-b,0);
                    itx1 = s1.find(t);
                    po[++n] = t;
                    po[n].pp = n;
                    nnew[a] = n;
                    if(itx1 == s1.end())
                    {
                        fa[n] = n;
                        s1.insert(po[n]);
                        s2.insert(po[n]);
                        nct[n]=1;
                    }
                    else
                    {
                        nct[n]=0;
                        fa[n] = itx1->pp;
                        nct[itx1->pp]++;
                    }
                }
                if(temp == 'R')
                {
                    int k = findit(newa);
                    if(nct[k] == 1)
                    {
                        s1.erase(po[k]);
                        s2.erase(po[k]);
                    }
                    nct[k]--;
                    node t = make(po[k].x,po[k].y+b,0);
                    itx1 = s1.find(t);
                    po[++n] = t;
                    po[n].pp = n;
                    nnew[a] = n;
                    if(itx1 == s1.end())
                    {
                        fa[n] = n;
                        s1.insert(po[n]);
                        s2.insert(po[n]);
                        nct[n]=1;
                    }
                    else
                    {
                        nct[n]=0;
                        fa[n] = itx1->pp;
                        nct[itx1->pp]++;
                    }
                }
                if(temp == 'U')
                {
                    int k = findit(newa);
                    if(nct[k] == 1)
                    {
                        s1.erase(po[k]);
                        s2.erase(po[k]);
                    }
//                    cout<<k<<endl;
                    nct[k]--;
                    node t = make(po[k].x-b,po[k].y,0);
                    itx1 = s1.find(t);
                    po[++n] = t;
                    po[n].pp = n;
                    nnew[a] = n;
                    if(itx1 == s1.end())
                    {
                        fa[n] = n;
                        s1.insert(po[n]);
                        s2.insert(po[n]);
                        nct[n]=1;
                    }
                    else
                    {
                        nct[n]=0;
                        fa[n] = itx1->pp;
                        nct[itx1->pp]++;
                    }
                }
                if(temp == 'D')
                {
                    int k = findit(newa);
                    if(nct[k] == 1)
                    {
                        s1.erase(po[k]);
                        s2.erase(po[k]);
                    }
//                    cout<<k<<endl;
                    nct[k]--;
                    node t = make(po[k].x+b,po[k].y,0);
                    itx1 = s1.find(t);
                    po[++n] = t;
                    po[n].pp = n;
                    nnew[a] = n;
                    if(itx1 == s1.end())
                    {
                        fa[n] = n;
                        s1.insert(po[n]);
                        s2.insert(po[n]);
                        nct[n]=1;
                    }
                    else
                    {
                        nct[n]=0;
                        fa[n] = itx1->pp;
                        nct[itx1->pp]++;
                    }
                }
            }
        }
    }
    return 0;
}