백준10014 Traveling Saga Problem

문제 링크

  • http://icpc.me/10014

사용 알고리즘

  • Centroid
  • 트리 이진 변환

시간복잡도

  • $O((N+Q) \log^2 N)$

풀이

어떤 정점에서 가장 먼 정점을 찾는 쿼리를 여러 번 수행하는 문제입니다. Centroid Decomposition을 이용하면 될 것 같다는 생각이 듭니다.

입력으로 주어진 트리 $T$의 정점 $v$와 가장 멀리 떨어져있는 정점 $u$는, Centroid Tree $C_T$에서 $v$의 조상 $p$를 따라가면서 $p$의 $C_T$ 상에서의 자손을 보면 됩니다. 이때 $C_T$의 각 정점에서는, 그 서브 트리에 속하는 정점 $x$에 대해 $(Dist(p, x), x)$를 max heap으로 저장하고 있습니다.
$u$는 $v$와 다른 서브트리에 속해야 하기 때문에, $v$의 $C_T$ 상의 조상 $p$의 자식 정점을 모두 순회해야 합니다.

만약 트리가 성게 그래프(star graph)라면 Centroid Tree의 정점의 Degree가 $N-1$이 돼서 각 쿼리가 $O(N \log^2 N)$이 됩니다. 트리를 이진 트리로 바꿔주면 자식의 개수(정점의 Degree)가 최대 3이 되므로 각 쿼리를 $O(3 \log^2 N)$에 처리할 수 있습니다.

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
/******************************
Author: jhnah917(Justice_Hui)
g++ -std=c++17 -DLOCAL
******************************/

#include <bits/stdc++.h>
#define x first
#define y second
#define all(v) v.begin(), v.end()
#define compress(v) sort(all(v)), v.erase(unique(all(v)), v.end())
#define IDX(v, x) (lower_bound(all(v), x) - v.begin())
using namespace std;

using uint = unsigned;
using ll = long long;
using ull = unsigned long long;
using PII = pair<int, int>;
constexpr int SZ = 505050, LG = 20, INF = 0x3f3f3f3f;

int N, pv;
vector<int> Inp[SZ];
vector<PII> G[SZ];
int S[SZ], U[SZ], P[LG][SZ], D[SZ], C[SZ];
int CP[SZ], CID[SZ];
priority_queue<PII> Info[SZ][4];

void Flush(priority_queue<PII> &pq){
    while(pq.size() && U[pq.top().y]) pq.pop();
}

void AddEdge(int s, int e, int x){
    G[s].emplace_back(e, x);
    G[e].emplace_back(s, x);
}

void Rebuild(int v, int b=-1){
    int lst = -1;
    for(auto i : Inp[v]){
        if(i == b) continue;
        Rebuild(i, v);
        if(lst == -1) AddEdge(v, i, 1), lst = v;
        else{
            int dummy = ++pv;
            AddEdge(lst, dummy, 0);
            AddEdge(dummy, i, 1);
            lst = dummy;
        }
    }
}

void Init(int v, int b=-1){
    for(int i=1; i<LG; i++) P[i][v] = P[i-1][P[i-1][v]];
    for(auto i : G[v]){
        if(i.x == b) continue;
        P[0][i.x] = v;
        D[i.x] = D[v] + 1;
        C[i.x] = C[v] + i.y;
        Init(i.x, v);
    }
}

int LCA(int u, int v){
    if(D[u] < D[v]) swap(u, v);
    int diff = D[u] - D[v];
    for(int i=0; diff; i++, diff>>=1) if(diff & 1) u = P[i][u];
    if(u == v) return u;
    for(int i=LG-1; ~i; i--) if(P[i][u] != P[i][v]) u = P[i][u], v = P[i][v];
    return P[0][u];
}

int Dist(int u, int v){
    return C[u] + C[v] - 2*C[LCA(u, v)];
}

int GetSize(int v, int b=-1){
    S[v] = 1;
    for(auto i : G[v]) if(!U[i.x] && i.x != b) S[v] += GetSize(i.x, v);
    return S[v];
}

int GetCent(int v, int tot, int b=-1){
    for(auto i : G[v]) if(!U[i.x] && i.x != b && S[i.x]*2 > tot) return GetCent(i.x, tot, v);
    return v;
}

void Gather(int rt, int id, int v, int b, int d){
    if(v <= N) Info[rt][id].emplace(d, v);
    for(auto i : G[v]) if(!U[i.x] && i.x != b) Gather(rt, id, i.x, v, d+i.y);
}

int CD(int v, int b=-1){
    int cent = GetCent(v, GetSize(v));
    U[cent] = 1; CP[cent] = b;

    int cnt = 1;
    if(cent <= N) Info[cent][0].emplace(0, cent);
    for(auto i : G[cent]) if(!U[i.x]) Gather(cent, cnt++, i.x, cent, i.y);

    cnt = 1;
    for(auto i : G[cent]) if(!U[i.x]) CID[CD(i.x, cent)] = cnt++;
    return cent;
}

int Query(int v){
    PII mx(-INF, -INF);
    for(int i=v, j=0; i!=-1; j=CID[i], i=CP[i]){
        int c = Dist(v, i);
        for(int k=0; k<4; k++){
            Flush(Info[i][k]);
            if(j == k || Info[i][k].empty()) continue;
            auto [d,u] = Info[i][k].top();
            mx = max(mx, PII(d+c, u));
        }
    }
    return mx.y;
}

int main(){
    ios_base::sync_with_stdio(false); cin.tie(nullptr);
    cin >> N;
    for(int i=1; i<N; i++){
        int s, e; cin >> s >> e;
        Inp[s].push_back(e);
        Inp[e].push_back(s);
    }
    for(int i=1; i<=N+N; i++) G[i].reserve(3);

    pv = N; Rebuild(1);
    Init(1);
    CD(1);

    memset(U, 0, sizeof U);

    int lst = 1; U[lst] = 1;
    cout << lst << " ";
    for(int i=1; i<N; i++){
        lst = Query(lst); U[lst] = 1;
        cout << lst << " ";
    }
}