백준16191 Utilitarianism

문제 링크

  • http://icpc.me/16191

사용 알고리즘

  • Alien’s Trick
  • Tree DP

시간복잡도

  • $O(N \log X)$

풀이

트리에서 매칭 K개를 골랐을 때 가중치 합의 최댓값을 구하는 문제입니다.

N = 100일 때의 풀이를 먼저 생각해봅시다.

네, 맞습니다. 이분 그래프에서 매칭 K개의 최소/최대 비용은 MCMF로 구할 수 있습니다. 그래프 모델링하고 플로우를 K만큼 흘려주면 됩니다.

이제 N = 100,000일 때의 풀이를 생각해봅시다.
MCMF로 답을 구할 수 있다는 것은, 답이 K에 대해 볼록하다는 것을 의미합니다. 다시 말해, K = t일 때의 정답을 f(t)라고 하면 (K, f(K)) 점들을 찍었을 때 위로 볼록한 형태의 그래프가 나옵니다.
Alien’s Trick을 이용해 쉽게 답을 구할 수 있습니다.

Challenge

트리에서 매칭을 1개, 2개, … , N-1개 골랐을 때 가중치 합의 최댓값을 모두 구하는 건 어떻게 해결할 수 있을까요?(문제)
Utilitarianism 풀이를 N번 수행하면 $O(N^2 \log X)$에 해결할 수 있습니다. 이것보다 빠르게 풀 수 있을까요?

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#include <bits/stdc++.h>
#define x first
#define y second
using namespace std;

typedef long long ll;
typedef pair<ll, ll> p;
const ll small = 0xc0c0c0c0c0c0c0c0;

int n, k;
vector<p> inp[252525], g[252525];
ll dp[2][252525], cst;
int cnt[2][252525];

void dfs(int v, int b, int d = 1){
    ll sum = 0, sumcnt = 0;
    ll match = 0, matchcnt = 0;
    for(auto i :g[v]) if(i.x != b){
        dfs(i.x, v, d+1);
        sum += dp[1][d+1]; sumcnt += cnt[1][d+1];
        ll t1 = dp[0][d+1] - dp[1][d+1] + i.y - cst;
        ll t2 = cnt[0][d+1] - cnt[1][d+1] + 1;
        if(p(match, matchcnt) < p(t1, t2)){
            match = t1, matchcnt = t2;
        }
    }
    dp[0][d] = sum; cnt[0][d] = sumcnt;
    dp[1][d] = sum + match; cnt[1][d] = sumcnt + matchcnt;
}
int chk(ll c){ cst = c; dfs(1, -1); return cnt[1][1]; }
int impossible(){ cst = -1e12; dfs(1, -1); return cnt[1][1] + 1; }

int main(){
    ios_base::sync_with_stdio(0); cin.tie(0);
    cin >> n >> k;
    for(int i=1; i<n; i++){
        int s, e, x; cin >> s >> e >> x;
        g[s].push_back({e, x});
        g[e].push_back({s, x});
    }
    int ed = impossible();
    if(ed <= k){ cout << "Impossible"; return 0; };
    ll l = -1e12, r = 1e12;
    while(l < r){
        ll m = l + r + 1 >> 1;
        ll t = chk(m);
        if(t >= k) l = m;
        else r = m - 1;
    }
    chk(l);
    cout << max(dp[0][1], dp[1][1]) + l * k;
}