首页 > 代码库 > BZOJ2243 [SDOI2011]染色(树链剖分+线段树合并)

BZOJ2243 [SDOI2011]染色(树链剖分+线段树合并)

题目链接 BZOJ2243

树链剖分 + 线段树（区间合并）

需要注意区间合并的细节：合并左右两个子区间时，若左区间右端颜色等于右区间左端颜色，颜色段数要减 1；沿重链向上跳时，若链顶与其父节点颜色相同，也要减 1；两条半路径在 LCA 处重合，最终答案再减 1。

#include <bits/stdc++.h>

using namespace std;

// Inclusive ascending / descending loop helpers used throughout the file.
#define rep(i, a, b)	for (int i = (a); i <= (b); ++i)
#define dec(i, a, b)	for (int i = (a); i >= (b); --i)

using LL = long long;   // NOTE(review): not referenced in the visible code

// Array bound for the number of tree nodes (with a little slack).
constexpr int N = 100010;

int n, m;            // number of nodes / number of operations (read by init)
int cnt;             // number of arcs stored in e[] so far (index 0 ends a list)
int sz;              // running DFS-order counter used by dfs2 to fill pl[]
int head[N];         // head[u]: index in e[] of the most recently added arc from u
int son[N];          // son[x]: size of the subtree rooted at x (filled by dfs1)
int deep[N];         // deep[x]: depth of x; the root (node 1) has depth 0
int belong[N];       // belong[x]: top node of the heavy chain containing x
int pl[N];           // pl[x]: position of x in the segment tree (DFS order)
int v[N];            // v[x]: initial color of node x
int st[N][20];       // st[x][i]: 2^i-th ancestor of x, 0 when it does not exist
bool vis[N];         // set by dfs1; NOTE(review): never read in the visible code

// Segment tree node over DFS positions.
struct segtree{
	int l, r;     // closed interval [l, r] of positions covered by this node
	int lc, rc;   // color at the left end / right end of the interval
	int s;        // number of maximal same-color segments in the interval
	int tag;      // pending "paint the whole interval" color, -1 means none
};
segtree t[N << 2];

// Adjacency-list arc: target node and index of the next arc with the same source.
struct edge{
	int to;
	int next;
};
edge e[N << 1];   // each undirected edge is stored as two arcs

// Adds the undirected edge {u, v} as two directed arcs (u -> v, then v -> u).
// Each arc is pushed onto the front of its source node's adjacency list.
inline void ins(int u, int v){
	const int arcs[2][2] = {{u, v}, {v, u}};
	for (int k = 0; k < 2; ++k){
		const int src = arcs[k][0];
		const int dst = arcs[k][1];
		e[++cnt].to = dst;
		e[cnt].next = head[src];
		head[src] = cnt;
	}
}

// First DFS over the subtree of x, whose parent is fa (called as dfs1(1, 0)).
// It fills:
//   son[x]       - subtree size of x,
//   deep[], st[][0] for each child - depth and direct parent,
//   st[x][i]     - 2^i-th ancestor for every 2^i <= deep[x].
// Ancestors are visited before x, so their lifting rows are already complete.
// NOTE(review): recursive; the depth equals the tree height, which can reach n
// on a path-shaped tree. Confirm the stack size suffices.
void dfs1(int x, int fa){
	vis[x] = true;
	son[x] = 1;

	for (int j = 1; j <= 17 && (1 << j) <= deep[x]; ++j)
		st[x][j] = st[st[x][j - 1]][j - 1];

	for (int i = head[x]; i; i = e[i].next){
		const int y = e[i].to;
		if (y != fa){
			deep[y] = deep[x] + 1;
			st[y][0] = x;
			dfs1(y, x);
			son[x] += son[y];
		}
	}
}

// Second DFS: gives x the next DFS position pl[x] and records pos as the top
// of its heavy chain. The heavy child is the child with the largest subtree
// size; on ties, the first one in adjacency order wins. It is visited first,
// so every heavy chain occupies a contiguous range of positions. Every other
// child starts its own chain. This must run after dfs1, because it reads
// deep[] and son[].
void dfs2(int x, int pos){
	pl[x] = ++sz;
	belong[x] = pos;

	int heavy = 0;   // son[0] == 0, so any real child beats the sentinel
	for (int i = head[x]; i; i = e[i].next){
		const int y = e[i].to;
		if (deep[y] > deep[x] && son[y] > son[heavy])
			heavy = y;
	}

	if (heavy == 0) return;   // leaf
	dfs2(heavy, pos);

	for (int i = head[x]; i; i = e[i].next){
		const int y = e[i].to;
		if (deep[y] <= deep[x] || y == heavy) continue;
		dfs2(y, y);
	}
}

int LCA(int x, int y){

	if (deep[x] < deep[y]) swap(x, y);
	int t = deep[x] - deep[y];
	rep(i, 0, 17) if (t & (1 << i))  x = st[x][i];
	dec(i, 17, 0) if (st[x][i] != st[y][i])
		x = st[x][i],
		y = st[y][i];
		
	
	if (x == y) return x;
	return st[x][0];
}

void build(int i, int l, int r){
	t[i].l = l;
	t[i].r = r;
	t[i].s = 1;
	t[i].tag = -1;
	
	if (l == r) return;

	int mid = (l + r) >> 1;
	
	build(i << 1, l, mid);
	build(i << 1 | 1, mid + 1, r);
}

void pushup(int k){
	t[k].lc = t[k << 1].lc;
	t[k].rc = t[k << 1 | 1].rc;
	if (t[k << 1].rc ^ t[k << 1 | 1].lc) t[k].s = t[k << 1].s + t[k << 1 | 1].s;
	else t[k].s = t[k << 1].s + t[k << 1 | 1].s - 1;
}

// Pushes node k's pending paint tag (if any) down to both children. Each
// child becomes one segment of that color and inherits the tag. Node k's
// tag is cleared unconditionally, including at leaves.
void pushdown(int k){
	const int pending = t[k].tag;
	t[k].tag = -1;
	if (pending == -1 || t[k].l == t[k].r) return;

	for (int c = k << 1; c <= (k << 1 | 1); ++c){
		t[c].s = 1;
		t[c].tag = pending;
		t[c].lc = pending;
		t[c].rc = pending;
	}
}

void change(int k, int x, int y, int c){
	pushdown(k);
	int l = t[k].l, r = t[k].r;
	if (l == x && r == y){
		t[k].lc = t[k].rc = c;
		t[k].s = 1;
		t[k].tag = c;
		return;
	}
	
	int mid = (l + r) >> 1;
	if (mid >= y) change(k << 1, x, y, c);
	else if (mid < x) change(k << 1 | 1, x, y, c);
	else{
		change(k << 1, x, mid, c);
		change(k << 1 | 1, mid + 1, y, c);
	}
	
	pushup(k);
}

int ask(int k, int x, int y){
	pushdown(k);
	int l = t[k].l, r = t[k].r;
	if (l == x && r == y) return t[k].s;
	
	int mid = (l + r) >> 1;
	if (mid >= y) return ask(k << 1, x, y);
	else if (mid < x) return ask(k << 1 | 1, x, y);
	else{
		int tmp = 1;
		if (t[k << 1].rc ^ t[k << 1 | 1].lc) tmp = 0;
		return ask(k << 1, x, mid) + ask(k << 1 | 1, mid + 1, y) - tmp;
	}
}

// Returns the color at position x. It walks down from node k to the leaf,
// pushing pending tags along the way. Note that this overloads the C
// library's getc(FILE*); the differing parameter list keeps them distinct.
int getc(int k, int x){
	for (;;){
		pushdown(k);
		if (t[k].l == t[k].r) return t[k].lc;
		const int mid = (t[k].l + t[k].r) >> 1;
		k = (x <= mid) ? (k << 1) : (k << 1 | 1);
	}
}

// Counts same-color segments on the vertical path from x up to f, both
// inclusive. The caller (solve) passes f = LCA, so f is an ancestor of x.
// Each jump covers one chain piece [top, x]. When the chain top has the
// same color as its parent, that boundary does not start a new segment,
// so one is subtracted.
int solvesum(int x, int f){
	int total = 0;
	for (; belong[x] != belong[f]; x = st[belong[x]][0]){
		const int top = belong[x];
		total += ask(1, pl[top], pl[x]);
		if (getc(1, pl[top]) == getc(1, pl[st[top][0]])) --total;
	}
	return total + ask(1, pl[f], pl[x]);
}

// Paints every node on the vertical path from x up to f (both inclusive)
// with color c. It paints one heavy-chain piece at a time. The caller
// passes f = LCA, an ancestor of x.
void solvechange(int x, int f, int c){
	for (; belong[x] != belong[f]; x = st[belong[x]][0])
		change(1, pl[belong[x]], pl[x], c);
	change(1, pl[f], pl[x], c);
}


// Runs the decomposition, builds the segment tree, loads the initial
// colors, then processes m commands from stdin:
//   "Q a b"   -> print the number of color segments on the path a .. b
//   otherwise -> read "a b c" and paint every node on the path a .. b with c
void solve(){
	int a, b, c;
	dfs1(1, 0);      // depths, subtree sizes, lifting table (root = node 1)
	dfs2(1, 1);      // heavy chains and DFS positions
	build(1, 1, n);

	rep(i, 1, n) change(1, pl[i], pl[i], v[i]);

	rep(i, 1, m){
		char ch[10];
		// Bounded read: the command is a single letter; never overflow ch.
		scanf("%9s", ch);
		if (ch[0] == 'Q'){
			scanf("%d%d", &a, &b);
			const int lca = LCA(a, b);
			// Both half-paths end at lca with the same color, so the
			// segment containing lca is counted twice; subtract it once.
			printf("%d\n", solvesum(a, lca) + solvesum(b, lca) - 1);
		}
		else{
			scanf("%d%d%d", &a, &b, &c);
			const int lca = LCA(a, b);
			solvechange(a, lca, c);
			solvechange(b, lca, c);
		}
	}
}

void init(){
	scanf("%d%d", &n, &m);
	rep(i, 1, n) scanf("%d", v + i);
	rep(i, 1, n - 1){
		int x, y;
		scanf("%d%d", &x, &y);
		ins(x, y);
	}
}

// Entry point: read the whole input, then answer every command.
int main(){
	init();    // n, m, colors, edges
	solve();   // preprocessing + query/paint loop
	return 0;
}

 

BZOJ2243 [SDOI2011]染色(树链剖分+线段树合并)