백준2213 트리의 독립집합

문제 링크

  • http://icpc.me/2213

사용 알고리즘

  • Tree DP

시간복잡도

  • O(n)

풀이

그래프에서 정점의 부분집합 S에 있는 모든 정점의 쌍이 인접하지 않는 집합을 독립집합이라고 정의했습니다
만약 정점에 가중치가 없다면 독립집합의 크기는 정점의 수이고, 가중치가 있다면 독립집합의 크기는 가중치의 합이 됩니다.
이 문제에서는 정점에 가중치가 있는 트리에서 최대 독립 집합을 구하는 것을 요구하고 있습니다.

트리는 비선형 구조입니다. 탐색 순서를 정하기 위해 dfs 트리를 만들어줍시다.
이 때 g[][]는 입력 데이터를 이용해 만든 인접리스트, tree[][]는 dfs트리를 저장할 인접리스트입니다.

1
2
3
4
5
6
7
8
9
void dfs(int now, int prv){
	for(int i=0; i<g[now].size(); i++){
		int nxt = g[now][i];
		if(nxt != prv){
			tree[now].push_back(nxt);
			dfs(nxt, now);
		}
	}
}

탐색 순서를 정했다면 dp배열을 정의해봅시다.
dp[i][0] = i번 노드를 루트로 하는 서브트리에서, i노드를 포함하지 않는 경우의 답
dp[i][1] = i번 노드를 루트로 하는 서브트리에서, i노드를 포함하는 경우의 답
으로 정의를 합시다.

dp[i][1]은 i번 노드를 포함한다는 의미이기 때문에 자식 노드들은 모두 포함하지 않습니다. 그러므로 dp[i][1]은 모든 dp[k][0]의 합이 될 것입니다.(단, k는 i번 노드의 자식 노드입니다.)
dp[i][0]은 i번 노드를 포함하지 않는다는 의미이기 때문에 자식 노드는 포함해도 되고 포함하지 않아도 됩니다. 문제에서는 최대 독립 집합을 요구하므로 포함하는 경우와 그렇지 않은 경우 중, 더 큰 값을 채택해야 합니다. 그러므로 dp[i][0]은 모든 max(dp[k][0], dp[k][1])의 합이 됩니다. (단, k는 i번 노드의 자식 노드입니다.)
dp table을 채우면서 각 노드의 포함 여부도 함께 기록합시다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
int dp(int now, bool include){
	int &res = table[now][include];
	if(res != -1) return res;

	if(include){ //포함하면
		int ans = 0;
		for(int i=0; i<tree[now].size(); i++){
			int nxt = tree[now][i];
			ans += dp(nxt, 0); //다음 노드 포함 안함
		}
		return res = ans + cost[now];
	}else{
		int ans = 0;
		for(int i=0; i<tree[now].size(); i++){
			int nxt = tree[now][i];
			int t1 = dp(nxt, 0);
			int t2 = dp(nxt, 1);
			if(t1 < t2){ //포함한게 더 크면
				chk[nxt] = 1;
			}
			else{
				chk[nxt] = 0;
			}
			ans += max(t1, t2);
		}
		return res = ans;
	}
}

dp table을 채우면서 기록한 각 노드의 포함여부를 이용해 역추적을 해봅시다.
역추적은 루트 노드부터 탐색을 시작해 현재 탐색 중인 노드가 포함이 되어 있다면 정답 배열에 넣어주면 됩니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
void track(int now, int include){
	if(include){
		ansarr.push_back(now);
		for(int i=0; i<tree[now].size(); i++){
			int nxt = tree[now][i];
			track(nxt, 0);
		}
	}
	else{
		for(int i=0; i<tree[now].size(); i++){
			int nxt = tree[now][i];
			track(nxt, chk[nxt]);
		}
	}
}

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#include <bits/stdc++.h>
using namespace std;

const int inf = 1e9+7;
vector<int> g[10010]; //그래프
vector<int> tree[10010]; //dfs 트리
int table[10010][2]; //dp테이블
int chk[10010]; //포함 여부
int maxi = -inf;
int cost[10010]; //노드 가중치
vector<int> ansarr; //정답

void dfs(int now, int prv){
	for(int i=0; i<g[now].size(); i++){
		int nxt = g[now][i];
		if(nxt != prv){
			tree[now].push_back(nxt);
			dfs(nxt, now);
		}
	}
}

int dp(int now, bool include){
	int &res = table[now][include];
	if(res != -1) return res;

	if(include){ //포함하면
		int ans = 0;
		for(int i=0; i<tree[now].size(); i++){
			int nxt = tree[now][i];
			ans += dp(nxt, 0); //다음 노드 포함 안함
		}
		return res = ans + cost[now];
	}else{
		int ans = 0;
		for(int i=0; i<tree[now].size(); i++){
			int nxt = tree[now][i];
			int t1 = dp(nxt, 0);
			int t2 = dp(nxt, 1);
			if(t1 < t2){ //포함한게 더 크면
				chk[nxt] = 1;
			}
			else{
				chk[nxt] = 0;
			}
			ans += max(t1, t2);
		}
		return res = ans;
	}
}

void track(int now, int include){
	if(include){
		ansarr.push_back(now);
		for(int i=0; i<tree[now].size(); i++){
			int nxt = tree[now][i];
			track(nxt, 0);
		}
	}
	else{
		for(int i=0; i<tree[now].size(); i++){
			int nxt = tree[now][i];
			track(nxt, chk[nxt]);
		}
	}
}

int main(){
	ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0);
	int n; cin>>n;
	for(int i=1; i<=n; i++){
		cin>>cost[i];
		table[i][0] = table[i][1] = -1;
	}
	for(int i=0; i<n-1; i++){
		int a, b; cin>>a>>b;
		g[a].push_back(b);
		g[b].push_back(a);
	}
	dfs(1, 1);

	int t1 = dp(1, 0);
	int t2 = dp(1, 1);
	if(t1 < t2) chk[1] = 1;
	else chk[1] = 0;
	int maxi = max(t1, t2);
	cout << maxi << "\n";
	track(1, chk[1]);
	sort(ansarr.begin(), ansarr.end());
	for(int i=0; i<ansarr.size(); i++) cout << ansarr[i] << " ";
}