백준6549 히스토그램에서 가장 큰 직사각형

문제 링크

  • http://icpc.me/6549

문제 출처

  • University of Ulm Local Contest 2003 H번

사용 알고리즘

  • 스택

시간복잡도

  • O(n)

풀이

먼저 Naive한 풀이는 O(n2)입니다. 모든 경우에 대해 탐색을 하면 되죠.
이 문제는 n이 최대 1e5이기 때문에 O(nlogn) 혹은 O(n)으로 풀어야 합니다.
Segment Tree를 사용하면 O(nlogn)에 가능합니다만, 구현이 더 쉽고 더 빠른 스택을 이용한 O(n) 풀이를 설명하겠습니다.

기본적인 아이디어를 알아봅시다.
left[i]라는 배열에는 i부터 왼쪽으로 뻗어나갈 때 어디까지 뻗어나갈 수 있는지를, right[i]는 오른쪽으로 어디까지 뻗어나갈 수 있는지를 나타냅니다.

0번 인덱스부터 2 2 4 3 1 3 3 이 순서대로 입력된다고 가정합시다.
0번 인덱스에 있는 2는 왼쪽 끝까지 뻗어나갈 수 있으므로 left[0] = -1
1번 인덱스에 있는 2도 왼쪽 끝까지 뻗어나갈 수 있으므로 left[1] = -1
2번 인덱스에 있는 4는 왼쪽으로 뻗어나가다 보면 1번 인덱스인 2에 의해 가로막히게 됩니다. 그러므로 left[2] = 1
3번 인덱스에 있는 3도 왼쪽으로 뻗어나가다 보면 1번 인덱스인 2에 의해 가로막히게 됩니다. 그러므로 left[3] = 1
4번 인덱스에 있는 1은 왼쪽 끝까지 뻗어나갈 수 있으므로 left[4] = -1
5번 인덱스에 있는 3은 왼쪽으로 뻗어나가다 보면 4번 인덱스인 1에 의해 가로막히게 됩니다. 그러므로 left[5] = 4
6번 인덱스에 있는 3은 왼쪽으로 뻗어나가다 보면 4번 인덱스인 1에 의해 가로막히게 됩니다. 그러므로 left[6] = 4

left배열을 naive히게 구현하면 O(n2)이 나오게 됩니다. 그러면 left배열을 구하는 목적이 사라지게 됩니다.
스택을 사용하면 O(n)만에 구할 수 있습니다.
스택에 들어가는 원소들을 순증가 상태를 유지하게 해야 합니다.
그렇게 하기 위해서는 현재 값보다 스택의 top에 있는 값이 같거나 큰 경우에는 pop을 해야 합니다.
그러면 스택은 대략적으로 아래와 같은 상태를 유지하게 됩니다. 이렇게 된다면, 10은 5가 나오기 이전의 모든 공간에서 뻗어나갈 수 있고, 15는 10이 나오기 이전의 모든 공간에서 뻗어나갈 수 있습니다.
스택의 원소들을 순증가 상태로 유지하면 이러한 특성을 이용해 최적화를 할 수 있습니다.

left를 구해봅시다.

1
2
3
4
5
6
7
8
vector<ll> v(n), left(n);
vector<p> stk;
for(int i=0; i<n; i++) scanf("%lld", &v[i]);
for(int i=0; i<n; i++){
	while(!stk.empty() && stk.back().first>=v[i]) stk.pop_back();
	left[i] = stk.empty()?-1:stk.back().second;
	stk.push_back({v[i], i});
}

right는 좌우 반전을 해주면 똑같이 구현 가능합니다.

1
2
3
4
5
for(int i=n-1; i>=0; i--){
	while(!stk.empty() && stk.back().first>=v[i]) stk.pop_back();
	right[i] = stk.empty()?n:stk.back().second;
	stk.push_back({v[i], i});
}

이제 left[i]에는 왼쪽으로 뻗어나갈 수 있는 최대, right[i]에는 오른쪽으로 뻗어나갈 수 있는 최대 위치를 구했습니다. 이제 양쪽으로 얼마나 뻗어나갈 수 있을지를 구해야 합니다.
right[i] - left[i] - 1이 양쪽으로 뻗어나갈 수 있는 최대 거리입니다.
코드로 옮겨봅시다.

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
typedef pair<ll, ll> p;

void solve(int n){
	vector<ll> v(n), left(n), right(n);
	vector<p> stk;
	ll ans;
	for(int i=0; i<n; i++) scanf("%lld", &v[i]);
	for(int i=0; i<n; i++){
		while(!stk.empty() && stk.back().first>=v[i]) stk.pop_back();
		left[i] = stk.empty()?-1:stk.back().second;
		stk.push_back({v[i], i});
	}
	stk.clear();
	for(int i=n-1; i>=0; i--){
		while(!stk.empty() && stk.back().first>=v[i]) stk.pop_back();
		right[i] = stk.empty()?n:stk.back().second;
		stk.push_back({v[i], i});
	}
	stk.clear();

	for(int i=0; i<n; i++){
		ll t = v[i] * (right[i] - left[i] - 1);
		ans = max(ans, t);
	}
	printf("%lld\n", ans);
}

int main(){
	int n;
	while(scanf("%d", &n) != 0){
		if(!n) return 0;
		solve(n);
	}
}