백준14176 트리와 소수

문제 링크

  • http://icpc.me/14176

사용 알고리즘

  • Centroid Decomposition
  • FFT

풀이

cnt[k] = 길이가 k인 경로의 개수 라고 정의합시다.
cnt[1]부터 cnt[N-1]까지 모두 구해주면 문제를 풀 수 있습니다.

$O(N^2)$가지의 경로를 모두 고려해야 하기 때문에 센트로이드를 생각할 수 있습니다.

centroid $C$를 지나는 모든 경로들을 고려해봅시다.
어떤 경로가 $C$를 지난다는 것은, $C$를 제거했을 때 나눠지는 서브트리 $T_1, T_2, \ldots , T_x$가 있을 때 서로 다른 서브트리 $T_i, T_j$에서 정점을 하나씩 선택해서 이은 경로를 의미합니다.

$T_i$에 있는 정점 중 센트로이드 $C$와 거리가 $k$만큼 떨어진 정점의 개수를 subtree[i][k] 라고 합시다.
$T_i$에 있는 정점 하나와 $T_j$에 있는 정점 하나를 이어서 만든 길이가 $k$인 경로의 개수는 subtree[i][a] × subtree[j][k-a] 로 계산할 수 있습니다. 그러므로 두 개의 서브트리를 선택해서 모든 경로의 길이를 고려해주는 것은 convolusion의 형태로 나타낼 수 있고, 이는 FFT로 계산해줄 수 있습니다.

즉, $T_1$부터 $T_x$ 까지 모두 보면서 (subtree[1] + subtree[2] + … + subtree[i-1])과 subtree[i]를 곱한 결과를 cnt에 누적시켜주면 됩니다.

Centroid Decomposition과 FFT를 열심히 구현하면 풀리는 문제입니다.

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#include <bits/stdc++.h>
#define x first
#define y second
#define all(v) v.begin(), v.end()
#define compress(v) sort(all(v)), v.erase(unique(all(v)), v.end())
using namespace std;

typedef long long ll;

vector<int> prime;
bool isp[101010];
void sieve(int n=100000){
    memset(isp, 1, sizeof isp);
    for(int i=2;i<=n;i++){
        if(isp[i]) prime.push_back(i);
        for(auto j : prime){
            if(i*j > n) break;
            isp[i*j]=0;
            if(i%j==0) break;
        }
    }
}

// FFT
typedef complex<double> cpx;
void fft(vector<cpx> &a, bool inv){
    int n = a.size(), j = 0;
    vector<cpx> roots(n/2);
    for(int i=1; i<n; i++){
        int bit = (n >> 1);
        while(j >= bit) j -= bit, bit >>= 1;
        j += bit;
        if(i < j) swap(a[i], a[j]);
    }
    double ang = 2 * acos(-1) / n * (inv ? -1 : 1);
    for(int i=0; i<n/2; i++) roots[i] = cpx(cos(ang * i), sin(ang * i));
    for(int i=2; i<=n; i<<=1){
        int step = n / i;
        for(int j=0; j<n; j+=i) for(int k=0; k<i/2; k++){
            cpx u = a[j+k], v = a[j+k+i/2] * roots[step * k];
            a[j+k] = u+v; a[j+k+i/2] = u-v;
        }
    }
    if(inv) for(int i=0; i<n; i++) a[i] /= n;
}
vector<ll> multiply(vector<ll> &v, vector<ll> &w){
    vector<cpx> fv(v.begin(), v.end()), fw(w.begin(), w.end());
    int n = 2; while(n < v.size() + w.size()) n <<= 1;
    fv.resize(n); fw.resize(n);
    fft(fv, 0); fft(fw, 0);
    for(int i=0; i<n; i++) fv[i] *= fw[i];
    fft(fv, 1);
    vector<ll> ret(n);
    for(int i=0; i<n; i++) ret[i] = (ll)round(fv[i].real());
    return ret;
}

ll n;
vector<int> g[101010];
int sz[101010], use[101010];
ll cnt[101010];

// Centroid
int getSize(int v, int p){
    sz[v] = 1;
    for(auto i : g[v]) if(i != p && !use[i]){
        sz[v] += getSize(i, v);
    }
    return sz[v];
}
int getCent(int v, int p, int s){
    for(auto i : g[v]) if(i != p && !use[i]){
            if(sz[i] > s/2) return getCent(i, v, s);
        }
    return v;
}

vector<ll> subtree, acc; int depth;

void update_sub(int v, int p, int d){
    depth = max(depth, d);
    subtree[d]++;
    for(auto i : g[v]) if(i != p && !use[i]) update_sub(i, v, d+1);
}

void calc(){
    auto t = multiply(subtree, acc);
    for(int i=1; i<t.size(); i++) cnt[i] += t[i];
}

void solve(int v){
    int cent = getCent(v, -1, getSize(v, -1));
    getSize(cent, -1);
    use[cent] = 1;
    acc.resize(1); acc.reserve(sz[cent]+1);
    acc[0] = 1;
    sort(all(g[cent]), [&](int a, int b){ return sz[a] < sz[b]; });

    for(auto i : g[cent]) if(!use[i]){
        depth = 0; subtree.clear(); subtree.resize(sz[i]+1);
        update_sub(i, cent, 1);
        calc();
        if(acc.size() <= depth) acc.resize(depth+1);
        for(int j=0; j<=depth; j++) acc[j] += subtree[j];
    }
    for(auto i : g[cent]) if(!use[i]) solve(i);
}

int main(){
    ios_base::sync_with_stdio(false); cin.tie(nullptr);
    cin >> n;
    for(int i=1; i<n; i++){
        int s, e; cin >> s >> e;
        g[s].push_back(e);
        g[e].push_back(s);
    }
    solve(1);
    sieve();
    ll sum = 0;
    for(auto i : prime) sum += cnt[i];
    cout << fixed; cout.precision(10);
    cout << 1.0 * sum / (n*(n-1)/2);
}