백준2867 수열의 값

문제 링크

  • http://icpc.me/2867

문제 출처

  • COCI 2010/2011 Contest #3 5번

사용 알고리즘

  • Stack

시간복잡도

  • O(n)

풀이

이 문제의 Naive한 풀이는 O(n3)의 시간 복잡도를 갖습니다.
이 문제에서 n은 최대 3e5이기 때문에 최소한 O(n log n)안에는 끝내야 합니다.
stack을 활용하면 O(n)만에 해결할 수 있고, O(n)이면 주어진 시간 내에 여유롭게 끝낼 수 있습니다.

이 문제는 4가지의 지표를 이용해 정답을 구합니다.

  1. 왼쪽에서 특정 지점까지는 내가 최댓값이다. -> A
  2. 왼쪽에서 특정 지점까지는 내가 최솟값이다. -> C
  3. 오른쪽에서 특정 지점까지는 내가 최댓값이다. -> B
  4. 오른쪽에서 특정 지점까지는 내가 최솟값이다. -> D

위 4가지 지표를 이용해서 특정 값이 최댓값인 구간과, 최솟값인 구간을 알 수 있습니다.
이 문제는 연속한 부분 수열의 최댓값과 최솟값을 구해서 계산하는 문제이므로 구간을 미리 구한다면 편하게 문제를 풀 수 있습니다.

현재 원소를 K번째 원소라고 가정합시다.
[A, B]범위에서는 K번째 원소가 최댓값이기 때문에 [A, K], (K, B] 범위에서 각각 하나씩 고른 값을 I, J라고 한다면 [I, J] 범위에서도 K번째 값이 최댓값입니다.
해당 경우의 수는 (K-A+1) * (B-K) 이기 때문에 정답에 K번째 원소를 (K-A+1) * (B-K)번 더해줍니다.

[C, D]범위에서도 K번째 원소가 최솟값이라는 점만 뺀다면 위의 경우와 같습니다. 다만, 정답에 K번째 원소를 (K-C+1) * (D-K)번 빼줍니다.

A, B, C, D는 스택을 이용해 쉽게 구할 수 있습니다.
A, C는 스택을 순 감소 상태로 유지하면서 쉽게 구할 수 있습니다.
B, D는 스택을 단조 증가 상태로 유지하면서 쉽게 구할 수 있습니다.
(A, C를 단조 감소 상태로 유지하고 B, D를 순 증가 상태로 유지해도 됩니다.)
원리는 이 문제(클릭)와 같습니다.

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
typedef pair<ll, ll> p;

int n;
vector<p> s;
vector<int> leftMin, leftMax, rightMin, rightMax;
vector<ll> v;

int main(){
	ios_base::sync_with_stdio(0); cin.tie(0);
	cin >> n;
	leftMin.resize(1+n); leftMax.resize(1+n); rightMin.resize(1+n); rightMax.resize(1+n); v.resize(1+n);
	for(int i=1; i<=n; i++) cin >> v[i];

	s.clear();
	for(int i=1; i<=n; i++){
		while(!s.empty() && s.back().first <= v[i]) s.pop_back();
		leftMax[i] = s.empty()?0:s.back().second;
		s.push_back({v[i], i});
	}

	s.clear();
	for(int i=n; i>0; i--){
		while(!s.empty() && s.back().first < v[i]) s.pop_back();
		rightMax[i] = s.empty()?n+1:s.back().second;
		s.push_back({v[i], i});
	}

	s.clear();
	for(int i=1; i<=n; i++){
		while(!s.empty() && s.back().first >= v[i]) s.pop_back();
		leftMin[i] = s.empty()?0:s.back().second;
		s.push_back({v[i], i});
	}

	s.clear();
	for(int i=n; i>0; i--){
		while(!s.empty() && s.back().first > v[i]) s.pop_back();
		rightMin[i] = s.empty()?n+1:s.back().second;
		s.push_back({v[i], i});
	}

	ll ans = 0;
	for(int i=1; i<=n; i++){
		ll cnt = (long long)(rightMin[i] - i) * (i - leftMin[i]);
		ans -= v[i] * cnt;
		cnt = (long long)(rightMax[i] - i) * (i - leftMax[i]);
		ans += v[i] * cnt;
	}
	cout << ans;
}