首页 > 代码库 > lct模板

lct模板

#include<iostream>
#include<cstring>
#include<algorithm>
#include<string>
#include<cstdio>

using namespace std;

#define MAXN 101000

int sz;
int ch[MAXN][2],f[MAXN];
int rev[MAXN],col[MAXN],node[MAXN];
int lc[MAXN],rc[MAXN],cnt[MAXN];


void newnode(int keys)
{
    int p=++sz;
    cnt[p]=1;
    lc[p]=rc[p]=node[p]=keys;
    rev[p]=ch[p][0]=ch[p][1]=f[p]=0;
    col[p]=-1;
}

 bool isroot(int x)
{
    return (!f[x]||ch[f[x]][0]!=x&&ch[f[x]][1]!=x);
}

void color(int rt,int w)
{
    if(!rt) return ;
    node[rt]=col[rt]=lc[rt]=rc[rt]=w;
    cnt[rt]=1;
}

void reverse(int rt)
{
    if(!rt) return;
    swap(ch[rt][0],ch[rt][1]);
    swap(lc[rt],rc[rt]);
    rev[rt]^=1;
}

void push_up(int rt)
{
    lc[rt]=ch[rt][0]!=0?lc[ch[rt][0]]:node[rt];
    rc[rt]=ch[rt][1]!=0?rc[ch[rt][1]]:node[rt];
    cnt[rt]=1;
    if(ch[rt][0])
    {
        cnt[rt]+=cnt[ch[rt][0]];
        if(node[rt]==rc[ch[rt][0]]) cnt[rt]--;
    }
    if(ch[rt][1])
    {
        cnt[rt]+=cnt[ch[rt][1]];
        if(node[rt]==lc[ch[rt][1]]) cnt[rt]--;
    }
}

void push_down(int rt)
{
    if(!rt) return;
    if(rev[rt])
    {
        reverse(ch[rt][0]);
        reverse(ch[rt][1]);
        rev[rt]=0;
    }
    if(col[rt]!=-1)
    {
        color(ch[rt][0],col[rt]);
        color(ch[rt][1],col[rt]);
        col[rt]=-1;
    }
}

void rotate(int x,int c)
{
    if(isroot(x)) return;
    int y=f[x],z=f[y];
    ch[y][!c]=ch[x][c];
    if(ch[x][c]) f[ch[x][c]]=y;
    ch[x][c]=y;
    if(f[y])
    {
        if(ch[f[y]][0]==y)
            ch[f[y]][0]=x;
        if(ch[f[y]][1]==y)
            ch[f[y]][1]=x;
    }
    f[y]=x;f[x]=z;
    push_up(y);
}

void splay(int x)
{
    int y,z;
    push_down(x);
    while(!isroot(x))
    {
        y=f[x],z=f[y];
        if(isroot(y))
        {
            push_down(y);push_down(x);
            if(ch[y][0]==x)
                rotate(x,1);
            else
                rotate(x,0);
        }
        else
        {
            push_down(z);push_down(y);push_down(x);
            if(ch[z][0]==y)
            {
                if(ch[y][0]==x)
                    rotate(y,1),rotate(x,1);
                else
                    rotate(x,0),rotate(x,1);
            }
            else
            {
                if(ch[y][1]==x)
                    rotate(y,0),rotate(x,0);
                else
                    rotate(x,1),rotate(x,0);
            }
        }
    }
    push_up(x);
}

int Access(int u)
{
    int v=0;
    for(;u;u=f[u])
    {
        splay(u);
        ch[u][1]=v;
        push_up(v=u);
    }
    return v;
}

void makeRoot(int u)
{
    int p=Access(u);
    reverse(p);
    splay(u);
}

int findRoot(int x) {
    for (x = Access(x); push_down(x),ch[x][0]!=0 ; x = ch[x][0]);
    splay(x);
    return x;
}
bool link(int u,int v)
{
    if(findRoot(u)==findRoot(v)) return false;
    makeRoot(u);
    f[u]=v;
    Access(u);
    return true;
}

bool cut(int u,int v)
{
    if(u==v||findRoot(u)!=findRoot(v)) return false;
    makeRoot(u);
    Access(v);
    splay(v);
    f[ch[v][0]]=0;
    ch[v][0]=0;
    push_up(v);
    return true;
}

bool modify(int u,int v,int d)
{
    if (findRoot(u)!=findRoot(v)) return false;
    makeRoot(u);
    Access(v);
    splay(v);
    color(v,d);
    return true;
}

bool query(int u,int v,int &ans)
{
    if (findRoot(u)!=findRoot(v)) return false;
    makeRoot(u);
    Access(v);
    splay(v);
    ans=cnt[v];
    return true;
}

int n,m;

int main()
{
    while(~scanf("%d%d",&n,&m))
    {
        sz=0;
        for(int i=1;i<=n;i++)
        {
            int c;
            scanf("%d",&c);
            newnode(c);
        }
        for(int i=0;i<n-1;i++)
        {
            int u,v;
            scanf("%d%d",&u,&v);
            link(u,v);
        }
        char str[4];
        int a,b,c;
        for(int i=0;i<m;i++)
        {
            scanf("%s",str);
            if(str[0]==C)
            {
                scanf("%d%d%d",&a,&b,&c);
                modify(a,b,c);
            }
            else
            {
                scanf("%d%d",&a,&b);
                query(a,b,c);
                printf("%d\n",c);
            }
        }
    }
    return 0;
}