백준2454 트리 분할

문제 링크

  • http://icpc.me/2454

문제 출처

  • 2011 KOI 고등부 3번

사용 알고리즘

  • 그리디

시간복잡도

  • $O(N \log N)$

풀이

일단 정점 하나를 잡아서 Rooted Tree로 만들어줍시다.

경로 단위로 묶기 때문에 같은 정점에 달려있는 서로 다른 세 간선은 한 집합에 묶일 수 없습니다. 그러므로 각 정점은 아래 3가지 상태 중 한 가지 상태에 속하게 됩니다.

  1. 서로 다른 두 자식 정점과 같이 묶인다.
  2. 자식 정점 한 개와 묶이고, 부모 정점과 묶일 가능성이 있다.
  3. 자식 정점과 묶이지 않고, 부모 정점과 묶일 가능성이 있다.

Rooted Tree에서 Tree DP 비슷한 느낌으로 각 정점 $i$마다 $(A_i, B_i)$ pair를 관리할 것입니다.

  • $A_i$: $i$번 정점을 루트로 하는 서브 트리에서의 정답
  • $B_i$: 이때 $i$번 정점을 포함하는 경로에 속한 정점 개수의 최솟값

당연히 $B_i$는 $K+1$ 이하가 되어야 합니다.
이제, 위에서 설명한 3가지 상태를 고려해서 정답을 찾아봅시다.

  1. 서로 다른 두 정점과 묶이는 경우: $i$의 자식 정점 두 개를 골라서 $B$값의 합이 $K$ 이하라면, 그 두 정점을 포함한 집합과 $i$를 묶을 수 있습니다. 이 경우 부모 정점과 묶일 수 없기 때문에 $B_i$를 $K+1$로 설정해줍니다. 그러므로 $(A_i, B_i) = (\sum A - 1, K+1)$이 됩니다.
  2. 자식 정점 한 개와 묶이는 경우: $i$의 자식 정점을 하나 골라서 $B$값이 $K$ 이하라면, 그 정점을 포함한 집합과 $i$를 묶을 수 있습니다. $(A_i, B_i) = (\sum A, \min(B) + 1)$이 됩니다.
  3. 자식 정점과 묶이지 않는 경우: $(A_i, B_i) = (\sum A+1, 1)$이 됩니다.

각 정점마다 가능한 $(A_i, B_i)$의 후보가 여러 개 있을 수 있는데, $A_i$가 가장 작은 것, $A_i$가 같다면 $B_i$가 가장 작은 것을 선택하는 것이 최적이라는 것을 증명할 수 있습니다.

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
/******************************
Author: jhnah917(Justice_Hui)
g++ -std=c++17 -DLOCAL
******************************/

#include <bits/stdc++.h>
#define x first
#define y second
#define all(v) v.begin(), v.end()
#define compress(v) sort(all(v)), v.erase(unique(all(v)), v.end())
#define IDX(v, x) (lower_bound(all(v), x) - v.begin())
using namespace std;

using uint = unsigned;
using ll = long long;
using ull = unsigned long long;
using PII = pair<int, int>;

int N, K;
vector<int> G[303030];

PII DFS(int v, int b){
    vector<PII> ch;
    for(const auto &i : G[v]) if(i != b) ch.push_back(DFS(i, v));
    sort(all(ch), [](const PII &p1, const PII &p2){ return p1.y < p2.y; });

    int sum = 0;
    for(const auto &i : ch) sum += i.x;

    PII ret(sum+1, 1);
    if(ch.size() >= 1 && ch[0].y < K+1) ret = min(ret, PII(sum,ch[0].y+1));
    if(ch.size() >= 2 && ch[0].y + ch[1].y < K+1) ret = min(ret, PII(sum-1, K+1));
    return ret;
}

int main(){
    ios_base::sync_with_stdio(false); cin.tie(nullptr);
    cin >> N >> K;
    for(int i=1; i<N; i++){
        int s, e; cin >> s >> e;
        G[s].push_back(e); G[e].push_back(s);
    }
    cout << DFS(1, -1).x;
}