首页 > 代码库 > [BZOJ4765]普通计算姬

[BZOJ4765]普通计算姬

[BZOJ4765]普通计算姬

试题描述

"奋战三星期,造台计算机"。小G响应号召,花了三小时造了台普通计算姬。普通计算姬比普通计算机要厉害一些。普通计算机能计算数列区间和,而普通计算姬能计算树中子树和。更具体地,小G的计算姬可以解决这么个问题:给定一棵n个节点的带权树,节点编号为1到n,以root为根,设sum[p]表示以点p为根的这棵子树中所有节点的权值和。计算姬支持下列两种操作:
1 给定两个整数u,v,修改点u的权值为v。
2 给定两个整数l,r,计算sum[l]+sum[l+1]+....+sum[r-1]+sum[r]
尽管计算姬可以很快完成这个问题,可是小G并不知道它的答案是否正确,你能帮助他吗?

输入

第一行两个整数n,m,表示树的节点数与操作次数。
接下来一行n个整数,第i个整数di表示点i的初始权值。
接下来n行每行两个整数ai,bi,表示一条树上的边,若ai=0则说明bi是根。
接下来m行每行三个整数,第一个整数op表示操作类型。
若op=1则接下来两个整数u,v表示将点u的权值修改为v。
若op=2则接下来两个整数l,r表示询问。
N<=10^5,M<=10^5
0<=Di,V<2^31,1<=L<=R<=N,1<=U<=N

输出

对每个操作类型2输出一行一个整数表示答案。

输入示例

6 4
0 0 3 4 0 1
0 1
1 2
2 3
2 4
3 5
5 6
2 1 2
1 1 1
2 3 6
2 3 5

输出示例

16
10
9

数据规模及约定

见“输入

题解

做法一:

我们先把 sum 数组求出来,然后节点 u 权值的修改对应着 u 到根节点这一条链的修改,于是可以树链剖分套数据结构完成这个操作;对于询问,它问的却是连续的一段编号;所以可以看成一个二维平面,每个节点 u 的坐标是 (dfs[u], u)(dfs[u] 表示节点 u 的树链剖分序),权值就是 sum[u];那么对于修改操作,就是 x 轴上连续的 log(n) 段,y 坐标没有限制;对于询问操作就是 y 轴上连续一段, x 坐标没限制;所以这个东西就可以用 kd 树维护了。

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cctype>
#include <algorithm>
using namespace std;

const int BufferSize = 1 << 16;
char buffer[BufferSize], *Head, *Tail;
inline char Getchar() {
    if(Head == Tail) {
        int l = fread(buffer, 1, BufferSize, stdin);
        Tail = (Head = buffer) + l;
    }
    return *Head++;
}
int read() {
	int x = 0, f = 1; char c = Getchar();
	while(!isdigit(c)){ if(c == ‘-‘) f = -1; c = Getchar(); }
	while(isdigit(c)){ x = x * 10 + c - ‘0‘; c = Getchar(); }
	return x * f;
}

#define maxn 100010
#define maxm 200010
#define UL unsigned long long
#define LL long long

int n, m, head[maxn], nxt[maxm], to[maxm], val[maxn];

void AddEdge(int a, int b) {
	to[++m] = b; nxt[m] = head[a]; head[a] = m;
	swap(a, b);
	to[++m] = b; nxt[m] = head[a]; head[a] = m;
	return ;
}

int rt, fa[maxn], siz[maxn], son[maxn], top[maxn], pos[maxn], ToT;
UL sum[maxn];
void build(int u) {
	siz[u] = 1; sum[u] = val[u];
	for(int e = head[u]; e; e = nxt[e]) if(to[e] != fa[u]) {
		fa[to[e]] = u;
		build(to[e]);
		siz[u] += siz[to[e]]; sum[u] += sum[to[e]];
		if(!son[u] || siz[son[u]] < siz[to[e]]) son[u] = to[e];
	}
	return ;
}
void gett(int u, int tp) {
	top[u] = tp; pos[u] = ++ToT;
	if(son[u]) gett(son[u], tp);
	for(int e = head[u]; e; e = nxt[e])
		if(to[e] != fa[u] && to[e] != son[u]) gett(to[e], to[e]);
	return ;
}

int Rt, ch[maxn][2];
bool Cur;
struct Node {
	int x[2], mx[2], mn[2], siz; UL sum; LL val, add;
	
	Node() {}
	Node(int x, int y, LL val): val(val), add(0) {
		this->x[0] = x; this->x[1] = y;
	}
	bool operator < (const Node& t) const { return x[Cur] != t.x[Cur] ? x[Cur] < t.x[Cur] : x[Cur^1] < t.x[Cur^1]; }
} ns[maxn];
void maintain(int o) {
	for(int j = 0; j < 2; j++) ns[o].mx[j] = ns[o].mn[j] = ns[o].x[j];
	ns[o].sum = ns[o].val; ns[o].siz = 1;
	for(int i = 0; i < 2; i++) if(ch[o][i]) {
		for(int j = 0; j < 2; j++)
			ns[o].mx[j] = max(ns[o].mx[j], ns[ch[o][i]].mx[j]),
			ns[o].mn[j] = min(ns[o].mn[j], ns[ch[o][i]].mn[j]);
		ns[o].sum += ns[ch[o][i]].sum;
		ns[o].siz += ns[ch[o][i]].siz;
	}
	ns[o].val += ns[o].add;
	ns[o].sum += ns[o].add * ns[o].siz;
	return ;
}
void build(int& o, int l, int r, bool cur) {
	if(l > r) return ;
	int mid = l + r >> 1; o = mid;
	Cur = cur; nth_element(ns + l, ns + mid, ns + r + 1);
	build(ch[o][0], l, mid - 1, cur ^ 1); build(ch[o][1], mid + 1, r, cur ^ 1);
	return maintain(o);
}
void pushdown(int o) {
	if(!ns[o].add) return ;
	for(int i = 0; i < 2; i++) if(ch[o][i]) {
		ns[ch[o][i]].add += ns[o].add;
		ns[ch[o][i]].val += ns[o].add;
		ns[ch[o][i]].sum += ns[o].add * ns[ch[o][i]].siz;
	}
	ns[o].add = 0;
	return ;
}
void upd(int o, int l, int r, int add) {
	pushdown(o);
	if(l <= ns[o].mn[0] && ns[o].mx[0] <= r) {
		ns[o].add += add;
		ns[o].val += add;
		ns[o].sum += (LL)add * ns[o].siz;
		return ;
	}
	if(l <= ns[o].x[0] && ns[o].x[0] <= r) ns[o].val += add;
	for(int i = 0; i < 2; i++)
		if(ch[o][i] && l <= ns[ch[o][i]].mx[0] && ns[ch[o][i]].mn[0] <= r)
			upd(ch[o][i], l, r, add);
	return maintain(o);
}
UL que(int o, int l, int r) {
	pushdown(o);
	if(l <= ns[o].mn[1] && ns[o].mx[1] <= r) return ns[o].sum;
	UL ans = (l <= ns[o].x[1] && ns[o].x[1] <= r) ? ns[o].val : 0;
	for(int i = 0; i < 2; i++)
		if(ch[o][i] && l <= ns[ch[o][i]].mx[1] && ns[ch[o][i]].mn[1] <= r)
			ans += que(ch[o][i], l, r);
	return ans;
}

void update(int u, int add) {
	while(u) upd(Rt, pos[top[u]], pos[u], add), u = fa[top[u]];
	return ;
}

#define maxol 2100000
char Output[maxol];
int num[21], cnt, cntol;

int main() {
//	freopen("common10.in", "r", stdin);
//	freopen("data.out", "w", stdout);
	n = read(); int q = read();
	for(int i = 1; i <= n; i++) val[i] = read();
	
	for(int i = 1; i <= n; i++) {
		int a = read(), b = read();
		if(!a) rt = b;
		else AddEdge(a, b);
	}
	build(rt);
	gett(rt, rt);
	
	for(int i = 1; i <= n; i++) ns[i] = Node(pos[i], i, sum[i]);
	build(Rt, 1, n, 0);
	while(q--) {
		int tp = read();
		if(tp == 1) {
			int u = read(), v = read();
			update(u, v - val[u]); val[u] = v;
		}
		else {
			int l = read(), r = read();
			UL tmp = que(Rt, l, r);
			cnt = 0;
			while(tmp) num[cnt++] = tmp % 10, tmp /= 10;
			for(int i = cnt - 1; i >= 0; i--) Output[cntol++] = num[i] + ‘0‘;
			Output[cntol++] = ‘\n‘;
		}
	}
	Output[--cntol] = ‘\0‘;
	puts(Output);
	
	return 0;
}

然而加了读入输出优化还是 T 飞。。。。。

解法二:

分块套分块。

我们先搞一个 dfs 序列,序列上存 val[i](即节点 i 的权值)。然后我们对这个序列分块,并维护两个信息:dfS[i] 表示位置 i 所在块的前缀和,dfSb[i] 表示前 i 个块的总和(即块的前缀和)。这样我们就可以 O(1) 询问区间和,O(sqrt(n)) 点修改了(想一想,为什么)。

然后我们再对正常编号的序列分块,并维护两个信息:tot[i][j] 表示第 i 块中所有 sum 使得对应 dfs 序列上位置 j 被计算了几次,Sb[i] 表示第 i 块中 sum 的总和。那么借助 tot 我们可以 O(n · sqrt(n)) 预处理 Sb,还可以 O(sqrt(n)) 支持点修改(想一想,为什么)。查询 [l, r] 时,对于被整个覆盖的块 i 直接累加 Sb[i] 就好了,对于没有被整个覆盖的块我们暴力找到这些点对应的 dfs 序上的区间,然后 O(1) 询问区间和,累加,就好了。

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <algorithm>
#include <cmath>
using namespace std;

int read() {
	int x = 0, f = 1; char c = getchar();
	while(!isdigit(c)){ if(c == ‘-‘) f = -1; c = getchar(); }
	while(isdigit(c)){ x = x * 10 + c - ‘0‘; c = getchar(); }
	return x * f;
}

#define maxn 100010
#define maxm 200010
#define maxb 320
#define UL unsigned long long
#define LL long long

int n, m, head[maxn], nxt[maxm], to[maxm];

void AddEdge(int a, int b) {
	to[++m] = b; nxt[m] = head[a]; head[a] = m;
	swap(a, b);
	to[++m] = b; nxt[m] = head[a]; head[a] = m;
	return ;
}

int rt, clo, dl[maxn], dr[maxn], id[maxn], val[maxn];
void build(int u, int fa) {
	dl[u] = ++clo; id[clo] = u;
	for(int e = head[u]; e; e = nxt[e]) if(to[e] != fa)
		build(to[e], u);
	dr[u] = clo;
	return ;
}

UL dfS[maxn], dfSb[maxb], Sb[maxb];
int tot[maxb][maxn], bl[maxn], st[maxn], en[maxn];

UL que(int l, int r) {
	return dfS[r] - (l > st[bl[l]] ? dfS[l-1] : 0) + dfSb[bl[r]-1] - dfSb[bl[l]-1];
}

int main() {
	n = read(); int q = read();
	for(int i = 1; i <= n; i++) val[i] = read();
	for(int i = 1; i <= n; i++) {
		int a = read(), b = read();
		if(!a) rt = b;
		else AddEdge(a, b);
	}
	
	build(rt, 0);
	int m = (int)sqrt(n);
	for(int i = 1; i <= n; i++) {
		bl[i] = (i - 1) / m + 1; if(!st[bl[i]]) st[bl[i]] = i; en[bl[i]] = i;
		dfS[i] = (bl[i-1] == bl[i] ? dfS[i-1] : 0) + (UL)val[id[i]];
		if(bl[i] != bl[i-1]) dfSb[bl[i]] = dfSb[bl[i-1]];
		dfSb[bl[i]] += val[id[i]];
	}
//	for(int i = 1; i <= bl[n]; i++) printf("[%d, %d]\n", st[i], en[i]);
	for(int i = 1; i <= bl[n]; i++) {
		for(int j = st[i]; j <= en[i]; j++) tot[i][dl[j]]++, tot[i][dr[j]+1]--;
		for(int j = 1; j <= n; j++)
			tot[i][j] += tot[i][j-1], Sb[i] += (UL)tot[i][j] * val[id[j]];
	}
	
	while(q--) {
		int tp = read();
		if(tp == 1) {
			int u = read(), v = read(), dv = v - val[u];
			val[u] = v;
			for(int i = dl[u]; i <= en[bl[dl[u]]]; i++) dfS[i] += dv;
			for(int i = bl[dl[u]]; i <= bl[n]; i++) dfSb[i] += dv;
			for(int i = 1; i <= bl[n]; i++) Sb[i] += (UL)tot[i][dl[u]] * dv;
		}
		else {
			int l = read(), r = read();
			UL ans = 0;
			if(bl[l] == bl[r]) for(int i = l; i <= r; i++) ans += que(dl[i], dr[i]);
			else {
				for(int i = bl[l] + 1; i < bl[r]; i++) ans += Sb[i];
				for(int i = l; i <= en[bl[l]]; i++) ans += que(dl[i], dr[i]);
				for(int i = st[bl[r]]; i <= r; i++) ans += que(dl[i], dr[i]);
			}
			printf("%llu\n", ans);
		}
	}
	
	return 0;
}

 

[BZOJ4765]普通计算姬