백준16121 사무실 이전

문제 링크

  • http://icpc.me/16121

사용 알고리즘

  • 센트로이드

시간복잡도

  • $O(N \log N)$

풀이

센트로이드 $c$를 지나는 모든 경로를 처리하고, 나머지 경로는 분할 정복을 통해 처리합시다.

두 정점 $u, v$를 잇는 경로가 $c$를 지난다는 것은, $c$를 기준으로 $u$와 $v$가 서로 다른 서브트리에서 왔다는 것을 의미합니다.
$c$를 루트로 잡아서 dfs를 돌렸을 때 $u, v$의 깊이가 각각 $d_u, d_v$라면, $u, v$를 잇는 경로의 길이의 제곱은 $(d_u + d_v)^2 = d_u^2 + 2d_ud_v + d_v^2$입니다.

서브트리 $T_1$에 있는 정점 $u_1, u_2, \dots, u_k$의 깊이가 각각 $d_{u_1}, d_{u_2}, \ldots, d_{u_k}$이고, 서브트리 $T_2$에 있는 정점 $v_1, v_2, \ldots, v_s$의 깊이가 각각 $d_{v_1}, d_v{v_2}, \ldots, d_{v_s}$일 때 $u_i$에서 $v_j$로 가는 모든 경로의 길이의 제곱의 합은 아래 식처럼 표현할 수 있습니다.

$\displaystyle s \times \sum_{i}(d_{u_i})^2 + t \times \sum_{j}(d_{v_j})^2 + 2 \times \sum_{i}d_{u_i} \times \sum_{j}d_{v_j}$

그러므로 각 서브트리마다 (직원의 수, 각 직원의 깊이의 합, 각 직원의 깊이 제곱의 합), (후보 역의 수, 각 후보 역의 깊이의 합, 각 후보 역의 깊이 제곱의 합)을 구한 뒤, $O((서브트리개수)^2)$만에 모든 경로의 결과를 계산해주면 됩니다.
정점이 $N$개인 트리의 정점의 차수가 최대 $N-1$인 것이 문제입니다.

트리에 최대 $N$개의 정점을 추가해 거리 관계를 유지하면서 이진 트리로 변환하는 테크닉이 있습니다. (아래에서 간략하게 설명합니다.) 이 방법을 적용해주면 이진트리에서 문제를 푸는 것이 되고, 이진 트리의 차수의 최댓값은 3이므로 $O(N \log N)$에 문제를 풀 수 있습니다.
센트로이드 자체도 하나의 서브트리로 생각하는 것이 구현이 편한 것 같습니다.

트리 이진 변환


어떤 정점의 자식의 개수가 3개 이상인 경우, 왼쪽에 자식 하나를 달고 오른쪽에 더미 정점을 넣은 뒤 그 더미 정점의 자식으로 옮겨주면 됩니다.
더미 정점으로 내려가는 간선의 가중치는 0, 다른 간선의 가중치는 원래 가중치를 넣어주면 거리 관계가 변하지 않습니다.

구현은 아래 코드의 make_binary함수를 참고하시면 됩니다.

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#include <bits/stdc++.h>
#define x first
#define y second
#define all(v) v.begin(), v.end()
#define compress(v) sort(all(v)), v.erase(unique(all(v)), v.end())
#define int long long
using namespace std;

typedef long long ll;
typedef pair<int, int> p;
const ll mod = 998244353;

struct Info{
    ll hs, hs2, hc, ks, ks2, kc;
    Info() : Info(0, 0, 0, 0, 0, 0) {}
    Info(ll hs, ll hs2, ll hc, ll ks, ll ks2, ll kc) : hs(hs), hs2(hs2), hc(hc), ks(ks), ks2(ks2), kc(kc) {}
};

int n, m, k, pv;
vector<p> inp[303030], g[606060];
int cnt[606060], chk[606060];

// now, real node, prev, adj list start index
void make_binary(int v = 1, int real = 1, int b = -1, int idx = 0){
    for(int _i=idx; _i<inp[v].size(); _i++){ auto i = inp[v][_i];
        if(i.x == b) continue;
        if(g[real].empty()){
            g[real].emplace_back(i.x, i.y);
            make_binary(i.x, i.x, v);
            g[i.x].emplace_back(real, i.y);
            continue;
        }
        if(_i+1 == inp[v].size() || _i+2 == inp[v].size() && inp[v][_i+1].x == b){
            g[real].emplace_back(i.x, i.y);
            make_binary(i.x, i.x, v);
            g[i.x].emplace_back(real, i.y);
            continue;
        }
        int nxt = ++pv;
        g[real].emplace_back(nxt, 0);
        make_binary(v, pv, b, _i);
        g[nxt].emplace_back(real, 0);
        break;
    }
}

int dep[606060], use[606060], sz[606060];
int get_sz(int v, int b = -1){
    sz[v] = 1;
    for(auto i : g[v]) if(i.x != b && !use[i.x]) sz[v] += get_sz(i.x, v);
    return sz[v];
}
int get_cent(int v, int N, int b = -1){
    for(auto i : g[v]) if(i.x != b && !use[i.x] && sz[i.x]*2 > N) return get_cent(i.x, N, v);
    return v;
}

ll ans;

Info dfs(int v, int d = 1, int b = -1){
    Info ret; dep[v] = d;
    ll H = cnt[v] * dep[v] % mod, K = chk[v] * dep[v] % mod;
    ret.hs += H; ret.hs2 += H * H % mod; if(cnt[v]) ret.hc++;
    ret.ks += K; ret.ks2 += K * K % mod; if(chk[v]) ret.kc++;
    for(auto i : g[v]) if(i.x != b && !use[i.x]) {
        Info tmp = dfs(i.x, d+i.y, v);
        ret.hs += tmp.hs; ret.hs2 += tmp.hs2; ret.hc += tmp.hc;
        ret.ks += tmp.ks; ret.ks2 += tmp.ks2; ret.kc += tmp.kc;
        ret.hs %= mod; ret.hs2 %= mod;
        ret.ks %= mod; ret.ks2 %= mod;
    }
    return ret;
}

void solve(int v = 1, int b = -1){
    int cent = get_cent(v, get_sz(v)); use[cent] = 1;
    vector<Info> ret(1);
    if(cnt[cent]) ret[0].hc = 1;
    if(chk[cent]) ret[0].kc = 1;
    for(auto i : g[cent]) if(i.x != b && !use[i.x]) {
        dep[i.x] = 1; ret.push_back(dfs(i.x, i.y));
    }
    for(int i=0; i<ret.size(); i++) for(int j=0; j<ret.size(); j++) if(i != j) {
        ans += ret[i].hc * ret[j].ks2 % mod;
        ans += ret[i].kc * ret[j].hs2 % mod;
        ans += ret[i].hs * ret[j].ks * 2 % mod;
        ans %= mod;
    }
    for(auto i : g[cent]) if(i.x != b && !use[i.x]) solve(i.x, cent);
}

int32_t main(){
    ios_base::sync_with_stdio(false); cin.tie(nullptr);
    cin >> n; pv = n;
    for(int i=1; i<n; i++){
        int s, e; cin >> s >> e;
        inp[s].emplace_back(e, 1); inp[e].emplace_back(s, 1);
    }
    cin >> m; for(int i=1; i<=m; i++){ int t; cin >> t; cnt[t] = 1; }
    cin >> k; for(int i=1; i<=k; i++){ int t; cin >> t; chk[t] = 1; }
    make_binary(); n = pv;
    solve();
    cout << ans;
}