백준17624 검은 돌

문제 링크

  • http://icpc.me/17624

문제 출처

  • 2019 KOI 2차대회 고등부 3번

사용 알고리즘

  • Tree DP

풀이

대회 도중에 문제를 보면서

  • Tree DP를 해야한다는 것과
  • 어떤 정점 v를 포함하면서 검은 돌은 i개 갖고 있는 서브 트리의 크기의 최대/최소를 구하면 될 것 같다

는 생각을 했었습니다.

구현을 못해서, 그리고 시간이 얼마 남지 않아서 못 풀긴 했습니다만…

dp1[v][i] = v를 포함하면서 검은 돌을 i개 갖고 있는 서브트리의 최소 크기
dp2[v][i] = v를 포함하면서 검은 돌을 i개 갖고 있는 서브트리의 최대 크기
라고 정의해놓고 tree dp를 돌려서 값을 미리 구해놓는다면, 각 쿼리는 O(1)에 처리할 수 있습니다.

v를 포함하면서 검은 돌은 i개 갖고 있는 서브 트리의 크기가 연속적인 것의 증명은 koosaga님 블로그에 잘 나와있습니다.

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#include <bits/stdc++.h>
#define x first
#define y second
#define all(v) v.begin(), v.end()
#define pb push_back
using namespace std;

typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> p;
const int inf = 1e9+7;

int n, m, q, ans;
vector<int> g[5050];
int dp1[5050][5050];
int dp2[5050][5050];
p res[5050];
int sz[5050], chk[5050];

void dfs(int v = 1, int p = -1){
    int now = 0;
    int arr[5050], brr[5050];
    for(auto i : g[v]){
        if(p == i) continue;
        dfs(i, v);
        memset(arr, 0x3f, sizeof arr);
        memset(brr, 0, sizeof brr);
        for(int j=0; j<=sz[i]; j++){
			for(int k=0; k<=now; k++){
				arr[j+k] = min(arr[j+k], dp1[i][j] + dp1[v][k]);
				brr[j+k] = max(brr[j+k], dp2[i][j] + dp2[v][k]);
			}
		}
		now += sz[i];
		for(int j=0; j<=now; j++){
			dp1[v][j] = arr[j];
			dp2[v][j] = brr[j];
		}
    }
	for(int i=0; i<=now; i++){
		dp1[v][i]++; dp2[v][i]++;
	}
    sz[v] = now;
	if(chk[v]){
        sz[v]++;
		for(int i=sz[v]; i; i--){
			dp1[v][i] = dp1[v][i-1];
			dp2[v][i] = dp2[v][i-1];
		}
		dp1[v][0] = 0;
		dp2[v][0] = 0;
	}else{
        dp1[v][0] = 0;
    }
}

int main(){
    ios_base::sync_with_stdio(0); cin.tie(0);
    cin >> n >> m;
    for(int i=1; i<=m; i++){
        int a; cin >> a; chk[a] = 1;
    }
    for(int i=1; i<n; i++){
        int s, e; cin >> s >> e;
        g[s].pb(e); g[e].pb(s);
    }

    for(int i=0; i<=m; i++) res[i] = {inf, -inf};
    memset(dp1, 0x3f, sizeof dp1);

    for(int i=1; i<=n; i++) dp1[i][0] = dp2[i][0] = 0;
    dfs();

    for(int i=1; i<=n; i++){
	    for(int j=0; j<=sz[i]; j++){
			res[j].x = min(res[j].x, dp1[i][j]);
			res[j].y = max(res[j].y, dp2[i][j]);
		}
	}
    cin >> q;
    while(q--){
        int a, b; cin >> a >> b;
        if(res[b].x <= a && a <= res[b].y) ans++;
    }
    cout << ans;

}