首页 > 代码库 > 校内训练0609 problem c

校内训练0609 problem c

【题目大意】

给定一棵 $n$ 个点的树，第 $i$ 个点有点权 $v_i$。求有多少条路径满足：路径上点权之和减去路径上点权的最大值是 $P$ 的倍数（单个点也算一条路径）。

数据范围：$n \le 10^5$，$P \le 10^7$。

【题解】

这是一道典型的点分治题。

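不需要做子树合并，而是仿照 POJ 1741 的做法：每次只统计经过当前重心 $c$ 的路径。记 $s(u)$ 为从 $c$ 到 $u$ 的路径点权和（含 $c$），$m(u)$ 为该路径上的最大点权，则经过 $c$ 的路径 $u \leadsto w$ 满足条件当且仅当 $s(u) + s(w) - v_c - \max(m(u), m(w)) \equiv 0 \pmod P$。用优先队列按 $m$ 从小到大扩展结点（子结点的 $m$ 不小于父结点的 $m$）。这样弹出 $w$ 时，之前加入的结点 $u$ 都满足 $m(u) \le m(w)$，路径最大值就是 $m(w)$。于是只需在桶中查询余数为 $(m(w) + v_c - s(w)) \bmod P$ 的结点个数，再把 $s(w) \bmod P$ 加入桶中即可。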

最后容斥去重：两个端点落在重心同一棵子树内的点对也会被上面的过程统计到。因此对重心的每个子结点，以重心作为路径前缀再统计一遍，并从答案中减去。

技术分享
# include <queue>
# include <stdio.h>
# include <string.h>
# include <iostream>
# include <algorithm>
// # include <bits/stdc++.h>

using namespace std;

typedef unsigned long long ull;
typedef long double ld;
typedef long long ll;

// Array capacities: N vertex slots, M directed-edge slots (each tree edge is
// stored twice), MAX residue buckets.
const int N = 100000 + 10;
const int M = 200000 + 10;
const int MAX = 10000000 + 5;
const int mod = 1000000000 + 7;   // not referenced elsewhere in this file

// Contest shorthand macros; neither is expanded anywhere in this file.
// (`register` is no longer a storage-class specifier as of C++17.)
# define RG register
# define ST static

int n, P, v[N];   // vertex count, modulus, vertex weights (read into v[1..n])

// Adjacency lists as singly linked lists of edge indices:
//   head[u] - most recently added edge leaving u (0 = none)
//   nxt[e]  - next edge leaving the same vertex (0 terminates the list)
//   to[e]   - target vertex of edge e
// Edge indices are handed out from 1 by incrementing tot.
int head[N], nxt[M], to[M], tot;

// Push the directed edge from -> dest onto the front of from's list.
inline void add(int from, int dest) {
    const int e = ++tot;
    to[e] = dest;
    nxt[e] = head[from];
    head[from] = e;
}

// Store an undirected edge as a pair of directed edges.
inline void adde(int a, int b) {
    add(a, b);
    add(b, a);
}

namespace DFZ {
    bool vis[N];
    int sz[N], mx[N];
    inline void dfsSize(int x, int fa = 0) {
        sz[x] = 1, mx[x] = 0; 
        for (int i=head[x]; i; i=nxt[i]) {
            if(to[i] == fa || vis[to[i]]) continue;
            dfsSize(to[i], x);
            sz[x] += sz[to[i]];
            if(sz[to[i]] > mx[x]) mx[x] = sz[to[i]];
        }
    }
    int mi, centre;
    inline void dfsCentre(int x, int tp, int fa = 0) {
        if(sz[tp] - sz[x] > mx[x]) mx[x] = sz[tp] - sz[x];
        if(mx[x] < mi) mi = mx[x], centre = x;
        for (int i=head[x]; i; i=nxt[i]) {
            if(to[i] == fa || vis[to[i]]) continue;
            dfsCentre(to[i], tp, x);
        }
    }
    
    struct pa {
        int x, s, mx, fa;
        pa() {}
        pa(int x, int s, int mx, int fa) : x(x), s(s), mx(mx), fa(fa) {}
        friend bool operator < (pa a, pa b) {
            return a.mx > b.mx;
        }
    };
        
    priority_queue<pa> q;
    int buc[MAX];
    int st[M], stn;
    
    inline void delAns(int x, int s, int fa) {
        -- buc[s];
        for (int i=head[x]; i; i=nxt[i]) {
            if(to[i] == fa || vis[to[i]]) continue;
            delAns(to[i], (s + v[to[i]]) % P, x);
        }
    }
    
    inline ll doit(int x, int temp_s, int temp_mx, int temp_fa, int Vx) {
        ll ret = 0; (temp_s += v[x]) %= P; temp_mx = max(temp_mx, v[x]);
        while(!q.empty()) q.pop(); stn = 0;
        q.push(pa(x, temp_s, temp_mx, temp_fa));
        while(!q.empty()) {
            pa tp = q.top(); q.pop();
            // tp.s + S - Vx - mx = 0 (mod P)
            // S = Vx + mx - tp.s
            ret += buc[((tp.mx + Vx - tp.s) % P + P) % P];
            ++ buc[tp.s];
            st[++stn] = tp.s;
            for (int i=head[tp.x]; i; i=nxt[i]) {
                if(to[i] == tp.fa || vis[to[i]]) continue;
                q.push(pa(to[i], (tp.s + v[to[i]]) % P, max(tp.mx, v[to[i]]), tp.x));
            }
        }
        for (int i=stn; i; --i) -- buc[st[i]];
        return ret;
    }

    ll ans;
    inline void dfs(int x) {
        dfsSize(x); mi = n;
        dfsCentre(x, x);
        x = centre;
        // ===== //
        // printf("x = %d\n", x);
        ans += doit(x, 0, 0, 0, v[x]);    
        // ===== //
        vis[x] = 1;
        for (int i=head[x]; i; i=nxt[i])
            if(!vis[to[i]]) {
                ans -= doit(to[i], v[x], v[x], x, v[x]);
                dfs(to[i]);
            }
    }
    
    inline void main() {
        ans = 0;
        dfs(1);
        ans += n;
        cout << ans << endl;
    }
}
        

int main() {
    // Contest I/O: read from c.in, write the answer to c.out.
    freopen("c.in", "r", stdin);
    freopen("c.out", "w", stdout);

    std::cin >> n >> P;

    // The n - 1 undirected tree edges.
    for (int e = 1; e < n; ++e) {
        int a, b;
        scanf("%d%d", &a, &b);
        adde(a, b);
    }

    // Vertex weights, stored 1-indexed.
    for (int k = 1; k <= n; ++k) scanf("%d", &v[k]);

    DFZ::main();
    return 0;
}
/*
Sample input:
5 2
1 2
1 3
2 4
3 5
1 3 3 1 2

Expected output: 9
*/
View Code

 

校内训练0609 problem c