백준20030 트리와 쿼리 17

문제 링크

  • http://icpc.me/20030

사용 알고리즘

  • Segment Tree & Lazy Propagation
  • HLD & Euler Tour Trick
  • Sparse Table
  • Parametric Search

시간복잡도

  • $O(N + Q \log^2 N)$

풀이

정점의 가중치가 바뀌는 상황에서 Weighted Centroid를 구하는 문제입니다.

1번 쿼리는 Euler Tour Trick으로, 2번 쿼리는 HLD로 처리할 수 있습니다.
Euler Tour Trick을 이용해 트리를 선형으로 폅시다. Weighted Centroid의 서브트리는 항상 median 지점을 포함합니다. 그러므로 median 지점에서 시작해 조상으로 올라가면서 Centroid를 찾으면 됩니다.
Naive하게 올라가면 N개의 정점을 탐색하기 때문에 비효율적입니다. Sparse Table을 이용한 Parametric Search를 이용하면 $O(\log N)$개의 정점만 봐도 됩니다. 자세한 구현은 아래 코드를 참고해주세요.

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#include <bits/stdc++.h>
#define x first
#define y second
#define all(v) v.begin(), v.end()
#define compress(v) sort(all(v)), v.erase(unique(all(v)), v.end())
using namespace std;

typedef long long ll;

struct Seg{
    ll tree[1 << 18], tmp[1 << 18];
    int n;
    void init(int N){ n = N; }
    void push(int node, int s, int e){
        if(!tmp[node]) return;
        tree[node] += (e-s+1) * tmp[node];
        if(s != e){
            tmp[node << 1] += tmp[node];
            tmp[node << 1 | 1] += tmp[node];
        }
        tmp[node] = 0;
    }
    void update(int node, int s, int e, int l, int r, ll v){
        push(node, s, e);
        if(r < s || e < l) return;
        if(l <= s && e <= r){
            tmp[node] += v; push(node, s, e);
            return;
        }
        int m = s + e >> 1;
        update(node << 1, s, m, l, r, v);
        update(node << 1 | 1, m+1, e, l, r, v);
        tree[node] = tree[node << 1] + tree[node << 1 | 1];
    }
    ll query(int node, int s, int e, int l, int r){
        push(node, s, e);
        if(r < s || e < l) return 0;
        if(l <= s && e <= r) return tree[node];
        int m = s + e >> 1;
        return query(node << 1, s, m, l, r) + query(node << 1 | 1, m+1, e, l, r);
    }
    int kth(int node, int s, int e, ll k){
        push(node, s, e);
        if(s == e) return s;
        int m = s + e >> 1;
        push(node << 1, s, m);
        push(node << 1 | 1, m+1, e);
        if(k <= tree[node << 1]) return kth(node << 1, s, m, k);
        else return kth(node << 1 | 1, m+1, e, k-tree[node << 1]);
    }
    void update(int s, int e, ll x = 1){ update(1, 1, n, s, e, x); }
    ll query(int s, int e){ return query(1, 1, n, s, e); }
    int kth(ll k){ return kth(1, 1, n, k); }
} seg;

int n, q; ll sum;
vector<int> inp[101010], g[101010];
int dep[101010], par[101010], sz[101010], top[101010], dp[22][101010];
int in[101010], out[101010], rev[101010], pv;
void dfs(int v, int b = -1){
    for(auto i : inp[v]) if(i != b){
        dep[i] = dep[v] + 1; par[i] = v; dp[0][i] = v;
        g[v].push_back(i);
        dfs(i, v);
    }
}
void dfs1(int v){
    sz[v] = 1;
    for(auto &i : g[v]){
        dfs1(i); sz[v] += sz[i];
        if(sz[i] > sz[g[v][0]]) swap(g[v][0], i);
    }
}
void dfs2(int v){
    in[v] = ++pv; rev[pv] = v;
    for(auto i : g[v]){
        top[i] = i == g[v][0] ? top[v] : i;
        dfs2(i);
    }
    out[v] = pv;
}
void updateSub(int v){
    seg.update(in[v], out[v]);
    sum += out[v] - in[v] + 1;
}
void updatePath(int u, int v){
    for(; top[u] != top[v]; u=par[top[u]]){
        if(dep[top[u]] < dep[top[v]]) swap(u, v);
        seg.update(in[top[u]], in[u]);
        sum += in[u] - in[top[u]] + 1;
    }
    if(dep[u] > dep[v]) swap(u, v);
    seg.update(in[u], in[v]);
    sum += in[v] - in[u] + 1;
}

int query(){
    int v = rev[seg.kth((sum+1)/2)];
    if(seg.query(in[v], out[v])*2 > sum) return v;
    for(int i=21; ~i; i--) if(dp[i][v]) {
        if(seg.query(in[dp[i][v]], out[dp[i][v]])*2 <= sum) v = dp[i][v];
    }
    return dp[0][v];
}

int main(){
    ios_base::sync_with_stdio(false); cin.tie(nullptr);
    cin >> n; seg.init(n);
    for(int i=1; i<n; i++){
        int s, e; cin >> s >> e;
        inp[s].push_back(e); inp[e].push_back(s);
    }
    top[1] = 1; dfs(1); dfs1(1); dfs2(1);
    for(int i=1; i<22; i++) for(int j=1; j<=n; j++) dp[i][j] = dp[i-1][dp[i-1][j]];
    cin >> q;
    while(q--){
        int op, a, b; cin >> op;
        if(op == 1){ cin >> a; updateSub(a); }
        else{ cin >> a >> b; updatePath(a, b); }
        cout << query() << "\n";
    }
}