백준15315 태풍의 아들 KDH

문제 링크

  • http://icpc.me/15315

사용 알고리즘

  • Centroid
  • FFT

풀이

경로의 길이를 $d$라고 했을 때, 해당 경로의 교통량의 기댓값은 $d \times (1-\frac{a}{b})^{d+1}$입니다.
cnt[k] = 길이가 k인 경로의 개수 로 정의한 다음 cnt[k]를 잘 구해주면 문제를 쉽게 풀 수 있습니다.

cnt배열을 구하는 것은 Centroid와 FFT로 할 수 있습니다.
BOJ14176 트리와 소수 문제의 풀이를 참고하시면 됩니다.

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
#pragma GCC target ("avx,avx2,fma")
#include <bits/stdc++.h>
#define x first
#define y second
#define all(v) v.begin(), v.end()
#define compress(v) sort(all(v)), v.erase(unique(all(v)), v.end())
using namespace std;

typedef long long ll;
const ll mod = 1e9+7;
const double pi = acos(-1);

ll pw(ll a, ll b){
    ll ret = 1;
    while(b){
        if(b & 1) ret = ret * a % mod;
        b >>= 1; a = a * a % mod;
    }
    return ret;
}

// FFT
#include <smmintrin.h>
#include <immintrin.h>
__m256d mult(__m256d a, __m256d b){
    __m256d c = _mm256_movedup_pd(a);
    __m256d d = _mm256_shuffle_pd(a, a, 15);
    __m256d cb = _mm256_mul_pd(c, b);
    __m256d db = _mm256_mul_pd(d, b);
    __m256d e = _mm256_shuffle_pd(db, db, 5);
    __m256d r = _mm256_addsub_pd(cb, e);
    return r;
}
void hell_joseon_fft(int n, __m128d a[], bool inv = false){
    for(int i=1, j=0; i<n; ++i){
        int bit = n>>1;
        for(;j>=bit;bit>>=1) j -= bit;
        j += bit;
        if(i<j) swap(a[i], a[j]);
    }
    for(int len=2; len<=n; len<<=1){
        double ang = 2*pi/len*(inv?-1:1);
        __m256d wlen; wlen[0] = cos(ang), wlen[1] = sin(ang);
        for(int i=0; i<n; i += len){
            __m256d w; w[0] = 1; w[1] = 0;
            for(int j=0; j<len/2; ++j){
                w = _mm256_permute2f128_pd(w, w, 0);
                wlen = _mm256_insertf128_pd(wlen, a[i+j+len/2], 1);
                w = mult(w, wlen);
                __m128d vw = _mm256_extractf128_pd(w, 1);
                __m128d u = a[i+j];
                a[i+j] = _mm_add_pd(u, vw);
                a[i+j+len/2] = _mm_sub_pd(u, vw);
            }
        }
    }
    if(inv){
        __m128d inv; inv[0] = inv[1] = 1.0/n;
        for(int i=0; i<n; ++i) a[i] = _mm_mul_pd(a[i], inv);
    }
}
vector<ll> multiply(vector<ll>& v, vector<ll>& w){
    int n = 2; while(n < v.size()+w.size()) n<<=1;
    __m128d* fv = new __m128d[n];
    for(int i=0; i<n; ++i) fv[i][0] = fv[i][1] = 0;
    for(int i=0; i<v.size(); ++i) fv[i][0] = v[i];
    for(int i=0; i<w.size(); ++i) fv[i][1] = w[i];
    hell_joseon_fft(n, fv); // (a+bi) is stored in FFT
    for(int i=0; i<n; i += 2){
        __m256d a;
        a = _mm256_insertf128_pd(a, fv[i], 0);
        a = _mm256_insertf128_pd(a, fv[i+1], 1);
        a = mult(a, a);
        fv[i] = _mm256_extractf128_pd(a, 0);
        fv[i+1] = _mm256_extractf128_pd(a, 1);
    }
    hell_joseon_fft(n, fv, 1);
    vector<ll> ret(n);
    for(int i=0; i<n; ++i) ret[i] = (ll)round(fv[i][1]/2);
    delete[] fv;
    return ret;
}

ll n, a, b;
vector<int> g[202020];
int sz[202020], use[202020];
ll cnt[202020];

// Centroid
int getSize(int v, int p){
    sz[v] = 1;
    for(auto i : g[v]) if(i != p && !use[i]){
            sz[v] += getSize(i, v);
        }
    return sz[v];
}
int getCent(int v, int p, int s){
    for(auto i : g[v]) if(i != p && !use[i]){
            if(sz[i] > s/2) return getCent(i, v, s);
        }
    return v;
}

vector<ll> subtree, acc; int depth;

void update_sub(int v, int p, int d){
    depth = max(depth, d);
    subtree[d]++;
    for(auto i : g[v]) if(i != p && !use[i]) update_sub(i, v, d+1);
}

void calc(){
    auto t = multiply(subtree, acc);
    for(int i=1; i<t.size(); i++) cnt[i] += t[i];
}

void solve(int v){
    int cent = getCent(v, -1, getSize(v, -1));
    getSize(cent, -1);
    use[cent] = 1;
    acc.resize(1); acc.reserve(sz[cent]+1);
    acc[0] = 1;
    sort(all(g[cent]), [&](int a, int b){ return sz[a] < sz[b]; });

    for(auto i : g[cent]) if(!use[i]){
            depth = 0; subtree.clear(); subtree.resize(sz[i]+1);
            update_sub(i, cent, 1);
            calc();
            if(acc.size() <= depth) acc.resize(depth+1);
            for(int j=0; j<=depth; j++) acc[j] += subtree[j];
        }
    for(auto i : g[cent]) if(!use[i]) solve(i);
}

int main(){
    ios_base::sync_with_stdio(false); cin.tie(nullptr);
    cin >> n >> a >> b; a = b - a;
    for(int i=1; i<n; i++){
        int s, e; cin >> s >> e;
        g[s].push_back(e);
        g[e].push_back(s);
    }
    solve(1);

    ll ans = 0;
    for(int i=1; i<=n; i++){
        ll t = 1LL * i * cnt[i] % mod;
        ll aa = pw(a, i+1);
        ll tt = pw(b, n-i);
        aa = aa * tt % mod;
        t = t * aa % mod;
        ans = (ans + t) % mod;
    }
    ans = ans * pw(pw(b, n+1), mod-2) % mod;
    cout << ans;
}