백준2336 굉장한 학생

문제 링크

  • http://icpc.me/2336

문제 출처

  • 2004 BalkanOI Day1 2번

사용 알고리즘

  • Segment Tree

시간복잡도

  • O(n log n)

풀이

자신보다 3개의 시험을 모두 잘 본 학생이 없는 학생은 굉장한 학생입니다.

먼저, 첫 번째 시험의 등수를 기준으로 정렬을 합니다.
이렇게 하면, 먼저 나온 학생이 적어도 그 뒤에 나온 학생보다 첫 번째 시험을 잘 봤다는 것은 자명한 사실이라는 것을 알 수 있고, 이를 통해 뒤에 나온 학생은 앞에 나온 학생보다 굉장한 학생일 확률이 0%가 됩니다.
그러므로 첫 번째 시험의 등수를 기준으로 정렬을 한 뒤, 특정 학생보다 앞에 있는 학생 중 2, 3번째 시험을 동시에 잘본 학생이 없다면 해당 학생은 굉장한 학생이라고 할 수 있습니다.

위에서 찾아낸 조건을 처리하기 위해서 Min Segment Tree를 사용할 것입니다.
먼저, 모든 노드를 INF값으로 초기화를 해줍니다.
그 후, 쿼리를 처리할 때마다 2번째 등수의 위치에 3번째 등수를 업데이트 해줍니다.
이렇게 하면, 1~나의 2번째 등수 구간을 확인하는 것과 나보다 1, 2번째 시험을 잘 본 학생들의 구간을 확인하는 것은 동치입니다.
구간에는 3번째 시험의 등수가 업데이트 되어있을 것이기 때문에 구간의 최솟값이 나의 3번째 시험의 등수보다 작다면 나보다 대단한 학생이 존재하는 것이므로 현재 처리하는 쿼리는 굉장한 학생이 아니게 됩니다.

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#include <bits/stdc++.h>
using namespace std;

struct asdf{

	int x, y, z;
};

bool cmp(asdf a, asdf b){
	return a.x < b.x;
}

const int size = 500000+10;
asdf arr[size];
int tree[4*size];

int update(int node, int start, int end, int idx, int value){
	if(idx < start || end < idx) return tree[node];
	if(start == end) return tree[node] = value;
	int mid = (start+end) >> 1;
	return tree[node] = min(
		update(node*2, start, mid, idx, value),
		update(node*2+1, mid+1, end, idx, value)
	);
}

int get(int node, int start, int end, int left, int right){
	if(right < start || end < left) return 1e9+7;
	if(left <= start && end <= right) return tree[node];
	int mid = (start+end) >> 1;
	return min(
		get(node*2, start, mid, left, right),
		get(node*2+1, mid+1, end, left, right)
	);
}

int main(){
	ios_base::sync_with_stdio(0); cin.tie(0);
	int n; cin >> n;
	for(int i=1; i<=n; i++){
		int t; cin >> t; arr[t].x = i;
	}
	for(int i=1; i<=n; i++){
		int t; cin >> t; arr[t].y = i;
	}
	for(int i=1; i<=n; i++){
		int t; cin >> t; arr[t].z = i;
	}
	sort(arr+1, arr+n+1, cmp);
	for (int i=1; i<=n; i++) update(1, 1, n, i, 1e9+7);
	int ans = 0;
	for (int i=1; i<=n; i++) {
        if (get(1, 1, n, 1, arr[i].y) > arr[i].z) ans++;
        update(1, 1, n, arr[i].y, arr[i].z);
    }

	cout << ans;
}