백준13514 트리와 쿼리 5

문제 링크

  • http://icpc.me/13514

사용 알고리즘

  • Centroid Decomposition

풀이

주어진 트리에서 Centroid Decomposition을 해서, Centroid Tree를 만들어주고 시작합시다.

set을 N개 관리할건데, i번 set에는 Centroid Tree상에서 루트와 i번 정점을 잇는 경로에 있는 흰색 정점과 i번 정점의 거리를 저장합니다.
Centroid Tree의 높이는 lgN이기 때문에, set에는 최대 lgN개의 정점이 들어가게 됩니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
//cdtree[i] = centroid tree에서 i번 정점의 부모
void update(int v){
    color[v] = !color[v];
    int i = v;
    while(1){
        int dist = getDist(i, v);
        if(color[v]) st[i].insert(dist);
        else{
            st[i].erase(st[i].find(dist));
        }
        if(i == cdtree[i]) break;
        i = cdtree[i];
    }
}

N개의 set을 잘 구해놓았다면, 쿼리는 간단합니다.
트리에서 임의의 정점 u와 v를 잇는 경로는 Centroid Tree 상에서 u와 v의 lca를 기준으로 나눌 수 있습니다.
그러므로 Centroid Tree상의 v의 선조를 모두 살펴보면 v와 다른 노드 사이의 모든 경로의 정보를 알 수 있습니다.
위에서 언급했던 Centroid Tree의 성질(Centroid Tree의 높이 = lgN)에 의해, 각 쿼리는 set의 시간 복잡도까지 합쳐서 O(lgn * lglgn)에 처리할 수 있습니다.

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/rope>
#define x first
#define y second
#define all(v) v.begin(), v.end()
#define compress(v) sort(all(v)), v.erase(unique(all(v)), v.end())
#define pb push_back
using namespace std;
//using namespace __gnu_pbds; //ordered_set : find_by_order(order), order_of_key(key)
//using namespace __gnu_cxx; //crope : append(str), substr(s, e), at(idx)

typedef long long ll;
typedef unsigned long long ull;
typedef pair<ll, ll> p;
//typedef tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update> ordered_set;

const int di[] = {1, 0, -1, 0, 1, 1, -1, -1}, dj[] = {0, 1, 0, -1, 1, -1, 1, -1};
ll gcd(ll x, ll y) { return y ? gcd(y, x%y) : x; }
ll lcm(ll x, ll y) { return x / gcd(x, y) * y; }
ll mod(ll a, ll b) { return ((a%b) + b) % b; }
ll ext_gcd(ll a, ll b, ll &x, ll &y) { //ax + by = gcd(a, b)
  ll g = a; x = 1, y = 0;
  if (b) g = ext_gcd(b, a % b, y, x), y -= a / b * x;
  return g;
}
ll inv(ll a, ll m){ //return x when ax mod m = 1, fail -> -1
    ll x, y;
    ll g = ext_gcd(a, m, x, y);
    if(g > 1) return -1;
    return mod(x, m);
}
void finish(){ exit(0); }

int sz[101010], use[101010], dep[101010];
int dp[22][101010];
vector<int> g[101010];
int cdtree[101010];
multiset<int> st[101010];
int color[101010];

void dfs(int now, int prv = -1){
    dp[0][now] = prv;
    dep[now] = dep[prv] + 1;
    for(auto i : g[now]){
        if(i == prv) continue;
        dfs(i, now);
    }
}

int lca(int u, int v){
    if(dep[u] < dep[v]) swap(u, v);
    int diff = dep[u] - dep[v];
    for(int i=0; diff; i++){
        if(diff & 1) u = dp[i][u];
        diff >>= 1;
    }
    if(u == v) return u;
    for(int i=21; i>=0; i--){
        if(dp[i][u] ^ dp[i][v]) u = dp[i][u], v = dp[i][v];
    }
    return dp[0][u];
}

int getDist(int u, int v){
    return dep[u] + dep[v] - 2 * dep[lca(u, v)];
}

int getSize(int now, int prv = -1){
    sz[now] = 1;
    for(auto i : g[now]){
        if(use[i] || prv == i) continue;
        sz[now] += getSize(i, now);
    }
    return sz[now];
}

int getCent(int now, int prv, int cnt){
    for(auto i : g[now]){
        if(use[i] || i == prv) continue;
        if(sz[i] > cnt/2) return getCent(i, now, cnt);
    }
    return now;
}

void cd(int now, int prv = -1){
    int cnt = getSize(now, prv);
    int cent = getCent(now, prv, cnt);
    if(prv == -1) cdtree[cent] = cent;
    else cdtree[cent] = prv;
    use[cent] = 1;
    for(auto i : g[cent]){
        if(!use[i]) cd(i, cent);
    }
}

void update(int v){
    color[v] = !color[v];
    int i = v;
    while(1){
        int dist = getDist(i, v);
        if(color[v]) st[i].insert(dist);
        else{
            st[i].erase(st[i].find(dist));
        }
        if(i == cdtree[i]) break;
        i = cdtree[i];
    }
}

int query(int v){
    int i = v;
    int ret = 1e9+7;
    while(1){
        int dist = getDist(i, v);
        if(!st[i].empty()) ret = min(ret, dist + *st[i].begin());
        if(i == cdtree[i]) break;
        i = cdtree[i];
    }
    if(ret > 1e9) return -1;
    return ret;
}

int main(){
    ios_base::sync_with_stdio(0); cin.tie(0);
    int n; cin >> n;
    for(int i=1; i<n; i++){
        int s, e; cin >> s >> e;
        g[s].push_back(e); g[e].push_back(s);
    }
    dfs(1);
    for(int j=1; j<22; j++){
        for(int i=1; i<=n; i++){
            dp[j][i] = dp[j-1][dp[j-1][i]];
        }
    }
    cd(1);
    int q; cin >> q;
    while(q--){
        int op, v; cin >> op >> v;
        if(op == 1){
            update(v);

        }else{
            cout << query(v) << "\n";
        }
    }
}