백준 15852 ANTS

문제 링크

  • http://icpc.me/15852

문제 출처

  • 2017 자카르타 리저널 H번

사용 알고리즘

  • 트리 압축
  • 트리 DP

시간복잡도

  • $O(N \log N + \sum K \log N)$

풀이

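전형적인 트리 압축 문제입니다.
각 쿼리마다 주어진 정점들을 DFS 방문 순서로 정렬하고, 인접한 두 정점의 LCA와 루트(1번 정점)를 추가해 트리를 압축합니다.
그 뒤 「BOJ 15647 로스팅하는 엠마도 바리스타입니다」처럼, 먼저 루트에서 모든 정점까지의 거리 합을 구하고 간선을 따라 내려가며 값을 갱신하는 Tree DP를 하면 됩니다. 정답은 압축된 트리의 정점들에 대한 거리 합 중 최솟값입니다.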

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
/******************************
Author: jhnah917(Justice_Hui)
g++ -std=c++17 -DLOCAL
******************************/

#include <bits/stdc++.h>
#define x first
#define y second
#define all(v) v.begin(), v.end()
#define compress(v) sort(all(v)), v.erase(unique(all(v)), v.end())
using namespace std;

using uint = unsigned;
using ll = long long;
using ull = unsigned long long;
using PLL = pair<ll, ll>;

// Input sizes: N is the number of tree vertices, Q the number of queries.
int N, Q;
// Par[k][v]: the 2^k-th ancestor of v; stays 0 when no such ancestor exists.
int Par[18][101010];
// Dep[v]: depth of v in the tree rooted at vertex 1 (the root has depth 0).
int Dep[101010];
// In[v] / Out[v]: first and last preorder index assigned inside v's subtree.
int In[101010];
int Out[101010];
// Shared counter: preorder clock during DFS, later reused by Solve and
// CompressTree as a cursor into `vertices`.
int pv;
// Adjacency lists of the input tree.
vector<int> G[101010];

/**
 * Preorder traversal of the input tree starting at v.
 *
 * Assigns In[v] on entry and Out[v] on exit (the largest preorder index seen
 * inside v's subtree), and fills Par[0] and Dep for every vertex reached.
 *
 * @param v current vertex
 * @param b vertex we arrived from; it is skipped so the edge is not walked
 *          back. Defaults to -1 for the initial call.
 */
void DFS(int v, int b=-1){
    In[v] = ++pv;
    for(const int nxt : G[v]){
        if(nxt == b) continue;
        Par[0][nxt] = v;
        Dep[nxt] = Dep[v] + 1;
        DFS(nxt, v);
    }
    Out[v] = pv;
}
/**
 * Lowest common ancestor of u and v, answered with the Par jump table.
 *
 * @return the deepest vertex that is an ancestor of (or equal to) both inputs
 */
int LCA(int u, int v){
    // Make u the deeper endpoint.
    if(Dep[u] < Dep[v]) swap(u, v);

    // Lift u to v's depth, one set bit of the depth gap at a time.
    int gap = Dep[u] - Dep[v];
    for(int bit = 0; gap != 0; ++bit, gap >>= 1){
        if(gap & 1) u = Par[bit][u];
    }
    if(u == v) return u;

    // Jump both upwards while their ancestors still differ; afterwards they
    // sit directly below the answer.
    for(int bit = 17; bit >= 0; --bit){
        const int upU = Par[bit][u];
        const int upV = Par[bit][v];
        if(upU != upV){
            u = upU;
            v = upV;
        }
    }
    return Par[0][u];
}

// K: number of vertices read for the current query (repeats are counted).
int K;
// C[v]: how many times v occurs in the current query.
int C[101010];
// S[v]: sum of C over v's subtree in the compressed tree.
int S[101010];
// Dist[v]: distance from vertex 1 to v, accumulated along compressed edges.
int Dist[101010];
// DP[v]: total distance from every query entry to v.
ll DP[101010];
// T[v]: children of v in the compressed tree, stored as (child, edge length).
vector<PLL> T[101010];
// Vertices of the compressed tree for the current query.
vector<int> vertices;

/**
 * Builds the compressed tree below v from `vertices`, which the caller has
 * sorted by preorder index In[].
 *
 * The global cursor pv points at the next unconsumed entry. Every entry whose
 * preorder index is still within v's subtree (In <= Out[v]) becomes a child
 * of v, with edge length equal to the depth difference, and is then expanded
 * recursively before the next entry is looked at.
 */
void CompressTree(int v){
    for(;;){
        const bool exhausted = pv >= static_cast<int>(vertices.size());
        if(exhausted || In[vertices[pv]] > Out[v]) return;

        const int child = vertices[pv];
        ++pv;
        T[v].emplace_back(child, Dep[child] - Dep[v]);
        CompressTree(child);
    }
}
void DFS1(int v){
    S[v] = C[v];
    DP[1] += C[v] * Dist[v];
    for(const auto &[i,c] : T[v]){
        Dist[i] = Dist[v] + c;
        DFS1(i); S[v] += S[i];
    }
}
void DFS2(int v){
    for(const auto &[i,c] : T[v]){
        DP[i] = DP[v] + (K - 2*S[i]) * c;
        DFS2(i);
    }
}

/**
 * Answers one query.
 *
 * Reads K followed by K vertices (a vertex may repeat; C[] keeps the
 * multiplicity), builds the compressed tree consisting of those vertices,
 * the LCAs of preorder-adjacent pairs and vertex 1, runs the two DP passes,
 * and prints the minimum DP value over all compressed-tree vertices followed
 * by a newline. All per-query state touched is reset before returning.
 */
void Solve(){
    cin >> K; vertices.resize(K);
    for(auto &i : vertices) cin >> i, C[i]++;

    const auto byPreorder = [](int u, int v){ return In[u] < In[v]; };
    sort(all(vertices), byPreorder);
    vertices.erase(unique(all(vertices)), vertices.end());

    // Iterate over the DISTINCT vertices only. K still counts repeats, so using
    // it as the bound would index past the end of the vector when all entries
    // are the same vertex, and would otherwise run LCA on already-appended
    // entries. The bound is captured before the loop because the loop appends.
    const int distinct = static_cast<int>(vertices.size());
    for(int i=1; i<distinct; i++) vertices.push_back(LCA(vertices[i-1], vertices[i]));
    vertices.push_back(1);  // vertex 1 is always the root of the compressed tree
    sort(all(vertices), byPreorder);
    vertices.erase(unique(all(vertices)), vertices.end());

    // vertices[0] is vertex 1 (smallest preorder index); start the cursor after it.
    pv = 1;
    CompressTree(1);
    DFS1(1);
    DFS2(1);

    ll R = 1e18;
    for(const auto &i : vertices) R = min(R, DP[i]);
    cout << R << "\n";

    // Reset only the entries this query touched so the next query starts clean.
    for(const auto i : vertices){
        C[i] = 0;
        S[i] = 0;
        Dist[i] = 0;
        DP[i] = 0;
        T[i].clear();
    }
}

/**
 * Entry point: reads the tree (N, then N-1 edges), roots it at vertex 1,
 * fills the ancestor jump table, then reads Q and answers Q queries.
 */
int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> N;
    for(int edge = 1; edge < N; ++edge){
        int a, b;
        cin >> a >> b;
        G[a].push_back(b);
        G[b].push_back(a);
    }

    // Preorder numbering, depths and direct parents.
    DFS(1);

    // Par[k][v] = 2^(k-1)-th ancestor of the 2^(k-1)-th ancestor of v.
    for(int k = 1; k < 18; ++k){
        for(int v = 1; v <= N; ++v){
            Par[k][v] = Par[k-1][Par[k-1][v]];
        }
    }

    cin >> Q;
    for(int query = 0; query < Q; ++query) Solve();
}