IOI1994 2번 The Castle

문제 링크

  • http://icpc.me/2234

문제 출처

  • 1994년 국제 정보 올림피아드 2번 (Day1 2번)

사용 알고리즘

  • BFS + bitmask
  • UnionFind + bitmask

시간복잡도

  • O(N4)
  • O(N2 * a(N2))

풀이

BFS 사용

1번 문제와 2번 문제는 단순한 BFS 문제입니다.
다만, 인접한 칸으로 이동할 때 벽이 있는지 없는지 확인을 해야 합니다. 벽의 존재 여부는 bitmask를 사용하여 확인할 수 있습니다.
여기까지는 O(N2)입니다.

3번 문제를 보면, 하나의 벽을 제거하여 얻을 수 있는 가장 넓은 방의 크기를 구해야 합니다. 모든 벽들을 보면서 하나씩 없애보고 BFS를 돌리는 방식으로 답을 구할 수 있습니다.
모든 벽을 순회하는데 O(N2), BFS를 돌리는데 O(N2)이 걸리므로 최종 시간 복잡도는 O(N4) 이 됩니다.

Union Find 사용

인접한 칸을 볼 때 BFS로 탐색을 하는 것이 아니라, Union Find로 합쳐준다고 생각을 해봅시다.

2번 문제는 단순하게 모든 집합(트리)의 사이즈 중 최댓값을 뽑으면 됩니다.
두 방을 Union을 하면 한 방으로 합쳐집니다. 총 방의 개수는 1이 줄어들게 됩니다. 그러므로 N * M - (union 횟수)가 1번 문제의 정답이 됩니다.
이 작업은 Union Find의 연산을 O(N2)번 사용합니다.

마지막으로 3번 문제를 봅시다.
하나의 벽을 제거하여 얻을 수 있는 방의 크기를 구해야 합니다.
인접한 칸이지만 벽으로 인해 union되지 않은 두 방의 크기의 합들을 모두 보면서 최댓값을 구해주면 됩니다.
이 작업도 Union Find의 연산을 O(N2)번 사용합니다.

그러므로 최종 시간 복잡도는 O(N2 * a(N2))입니다.

전체 코드

BFS 사용

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#include <bits/stdc++.h>
using namespace std;

typedef pair<int, int> p;

const int LEFT = 1;
const int UP = 2;
const int RIGHT = 4;
const int DOWN = 8;

int n, m;
int arr[55][55];
int tmp[55][55];
int chk[55][55];

bool bound(int i, int j){
	return 1 <= i && i <= n && 1 <= j && j <= m;
}

int bfs(int i, int j){
	queue<p> q; q.push({i, j});
	chk[i][j] = 1;
	int cnt = 0;
	while(!q.empty()){
		cnt++;
		int i = q.front().first;
		int j = q.front().second;
		//cout << i << " " << j << "\n";
		q.pop();
		if(!(tmp[i][j] & LEFT) && !chk[i][j-1] && bound(i, j-1)){
			q.push({i, j-1}); chk[i][j-1] = 1;
		}
		if(!(tmp[i][j] & UP) && !chk[i-1][j] && bound(i-1, j)){
			q.push({i-1, j}); chk[i-1][j] = 1;
		}
		if(!(tmp[i][j] & RIGHT) && !chk[i][j+1] && bound(i, j+1)){
			q.push({i, j+1}); chk[i][j+1] = 1;
		}
		if(!(tmp[i][j] & DOWN) && !chk[i+1][j] && bound(i+1, j)){
			q.push({i+1, j}); chk[i+1][j] = 1;
		}
	}
	return cnt;
}

int getCnt(){
	int ret = 0;
	memset(chk, 0, sizeof chk);
	for(int i=1; i<=n; i++){
		for(int j=1; j<=m; j++){
			if(chk[i][j]) continue;
			bfs(i, j); ret++;
		}
	}
	return ret;
}

int getArea(){
	int ret = 0;
	memset(chk, 0, sizeof chk);
	for(int i=1; i<=n; i++){
		for(int j=1; j<=m; j++){
			if(chk[i][j]) continue;
			ret = max(ret, bfs(i, j));
		}
	}
	return ret;
}

void getAns(){
	memcpy(tmp, arr, sizeof tmp);
	cout << getCnt() << "\n";
	cout << getArea() << "\n";

	int ans = 0;
	for(int i=1; i<=n; i++){
		for(int j=1; j<=m; j++){
			for(int k=0; k<4; k++){
				if(arr[i][j] & (1 << k)){
					memcpy(tmp, arr, sizeof tmp);
					tmp[i][j] -= (1 << k);
					ans = max(ans, getArea());
				}
			}
		}
	}
	cout << ans;
}

int main(){
	ios_base::sync_with_stdio(0); cin.tie(0);
	cin >> m >> n;
	for(int i=1; i<=n; i++){
		for(int j=1; j<=m; j++){
			cin >> arr[i][j];
		}
	}
	getAns();
}

Union Find 사용

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#include <bits/stdc++.h>
using namespace std;

int bit[] = {1, 2, 4, 8};
int di[] = {0, -1, 0, 1};
int dj[] = {-1, 0, 1, 0};

struct UnionFind{
	int par[2525];
	int sz[2525];

	UnionFind(){
		for(int i=0; i<2525; i++) par[i] = i, sz[i] = 1;
	}

	int find(int v){
		return v == par[v] ? v : par[v] = find(par[v]);
	}

	bool merge(int u, int v){
		u = find(u), v = find(v);
		if(u == v) return 0;
		if(sz[u] > sz[v]) swap(u, v);
		par[u] = v;
		sz[v] += sz[u];
		return 1;
	}

	int size(int i, int j){
		int v = 50 * (i-1) + j;
		v = find(v);
		return sz[v];
	}
}uf;

int n, m;
int arr[55][55];
int cnt;

inline int f(int i, int j){ return 50 * (i-1) + j; }
inline bool bound(int i, int j){ return 1<=i && i<=n && 1<=j && j<=m; }

int main(){
	ios_base::sync_with_stdio(0); cin.tie(0);
	cin >> m >> n; cnt = n * m;
	for(int i=1; i<=n; i++){
		for(int j=1; j<=m; j++){
			cin >> arr[i][j];
		}
	}

	for(int i=1; i<=n; i++){
		for(int j=1; j<=m; j++){
			for(int k=0; k<4; k++){
				int ii = i + di[k];
				int jj = j + dj[k];
				int x = bit[k];
				if(!bound(ii, jj)) continue;
				if(arr[i][j] & x) continue;
				if(uf.merge(f(i, j), f(ii, jj))) cnt--;
			}
		}
	}

	int mx = 0;
	for(int i=1; i<=n; i++){
		for(int j=1; j<=m; j++){
			mx = max(mx, uf.size(i, j));
		}
	}

	cout << cnt << "\n" << mx << "\n";

	mx = 0;
	for(int i=1; i<=n; i++){
		for(int j=1; j<=m; j++){
			for(int k=0; k<4; k++){
				int ii = i + di[k];
				int jj = j + dj[k];
				int x = bit[k];
				if(!bound(ii, jj)) continue;
				if(uf.find(f(i, j)) == uf.find(f(ii, jj))) continue;
				if(!(arr[i][j] & x)) continue;
				mx = max(mx, uf.size(i, j) + uf.size(ii, jj));
			}
		}
	}

	cout << mx;
}