백준16909 카드 구매하기3

문제 링크

  • http://icpc.me/16909

사용 알고리즘

  • Union-Find

풀이

모든 구간의 (최댓값 - 최솟값)의 합을 구하는 것은 최댓값을 다 더한 다음에 최솟값을 다 빼주는 것으로 볼 수 있습니다.
이런 경우, 최댓값을 구하는 방법만 알면 뒤집어서 최솟값도 구할 수 있기 때문에 최댓값을 구하는 방법만 설명하겠습니다.

최댓값이 x라면 x는 존재해야 하고, x보다 큰 값은 존재하면 안 됩니다.
비용이 낮은 카드부터 차례대로 처리해주면 카드를 점점 추가해나가는 형태로 구현해줄 수 있습니다.
어떤 카드 C가 최대인 구간은 C를 추가했을 때 C 왼쪽/오른쪽으로 연달아 놓여있는 카드들입니다.

구현 방법을 조금 더 자세히 짚어보자면, 비용이 x인 카드를 넣는 순간

  1. 이 카드의 왼쪽에 연달아 놓여있는 카드의 개수를 셉니다. 이 값을 L이라고 합시다.
  2. 오른쪽에 대해 같은 작업을 수행해줍니다. 그 값을 R이라고 합시다.
  3. 최댓값의 합은 x(L+1)(R+1)만큼 증가합니다.

3번이 (L+1)(R+1)인 이유는, 삽입한 카드를 포함하면서 구간의 왼쪽과 오른쪽 끝을 정해야 하기 때문입니다.

1, 2번 과정을 빠르게 하는 것은 union find를 이용하면 됩니다.
연달아 놓여있는 카드들을 한 집합으로 합쳐주는 형태로 구현하면 됩니다.

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#include <bits/stdc++.h>
using namespace std;

typedef long long ll;

int n;
ll arr[1010101];
vector<int> idx;

int par[1010101];
ll lf[1010101];
ll rf[1010101];

void init(){
	for(int i=1; i<=n; i++){
		par[i] = i;
		lf[i] = i, rf[i] = i;
	}
}

int find(int v){
	return v == par[v] ? v : par[v] = find(par[v]);
}

bool merge(int u, int v){
	u = find(u); v = find(v);
	if(u == v) return 0;
	if(u < 1 || v < 1) return 0;
	if(u > n || v > n) return 0;
	par[u] = v;
	lf[v] = min(lf[v], lf[u]);
	rf[v] = max(rf[v], rf[u]);
	return 1;
}

int main(){
	ios_base::sync_with_stdio(0); cin.tie(0);
	cin >> n; init();
	for(int i=1; i<=n; i++) cin >> arr[i], idx.push_back(i);

	sort(idx.begin(), idx.end(), [](int a, int b){
		if(arr[a] != arr[b]) return arr[a] < arr[b];
		return a < b;
	});
	ll ans = 0;
	for(int i=0; i<idx.size(); i++){
		int j = idx[i];
		if(arr[j] >= arr[j-1]) merge(j, j-1);
		if(arr[j] >= arr[j+1]) merge(j, j+1);
		if(i != idx.size()-1 && arr[j] == arr[idx[i+1]]){
			int k = idx[i+1];
			int r = j; if(arr[j] > arr[j+1]) r = rf[find(j+1)];
			if(arr[k] >= arr[k-1]) merge(k, k-1);
			int jj = find(j), kk = find(k);
			if(lf[kk] <= rf[jj]){
				ans += arr[j] * (j - lf[jj] + 1) * (r - j + 1);
				continue;
			}
		}
		int jj = find(j);
		ans += arr[j] * (rf[jj] - j + 1) * (j - lf[jj] + 1);
	}

	init();
	sort(idx.begin(), idx.end(), [](int a, int b){
		if(arr[a] != arr[b]) return arr[a] > arr[b];
		return a < b;
	});
	for(int i=0; i<idx.size(); i++){
		int j = idx[i];
		if(arr[j] <= arr[j-1]) merge(j, j-1);
		if(arr[j] <= arr[j+1]) merge(j, j+1);
		if(i != idx.size()-1 && arr[j] == arr[idx[i+1]]){
			int k = idx[i+1];
			int r = j; if(arr[j] < arr[j+1]) r = rf[find(j+1)];
			if(arr[k] <= arr[k-1]) merge(k, k-1);
			int jj = find(j), kk = find(k);
			if(lf[kk] <= rf[jj]){
				ans -= arr[j] * (j - lf[jj] + 1) * (r - j + 1);
				continue;
			}
		}
		int jj = find(j);
		ans -= arr[j] * (rf[jj] - j + 1) * (j - lf[jj] + 1);
	}
	cout << ans;
}