백준7469 K번째 수

문제 링크

  • http://icpc.me/7469

문제 출처

  • 2004 NEERC Northern Subregional K번

사용 알고리즘

  • Segment tree
  • Merge Sort Tree

시간복잡도

  • O(n log n + m log3 n)

풀이

이 문제는 segment Tree로 해결할 수 있는 문제입니다.
이 문제를 풀면서 나오는 Segment Tree의 형태가 Merge Sort의 진행 과정과 유사하다 해서 Merge Sort Tree라고 부르기도 합니다.

입력된 숫자들을 세그먼트 트리를 이용해 관리를 합니다. 트리의 각 노드는 대응하는 구간의 원소들을 정렬한 배열을 키값으로 갖습니다.
전에 올린 풀이들은 노드들이 상수를 키값으로 가졌지만, 이 문제에서는 배열을 키값으로 갖습니다.


이런 식으로 merge sort tree 형태를 갖게 됩니다.

이제 k번째 수를 찾아봅시다.

먼저, 특정 구간에서 k번째 수가 x라면 아래 두 가지를 만족합니다.

  • 해당 구간에 x 이하의 수는 k개 이상 존재
  • 해당 구간에 x 미만의 수는 k개 미만 존재

어떤 구간에서 x 이하의 수의 개수는 다음과 같이 재귀적으로 구할 수 있습니다.

  • 원하는 구간과 현재 탐색 중인 구간이 전혀 교차하지 않는다면, 0개
  • 원하는 구간이 현재 탐색 중인 구간을 완전히 포함한다면, 현재 탐색 중인 구간에서 이진 탐색
  • 일부만 교차한다면, 2개의 자식 노드에 대해 재귀적으로 위에 있는 두 연산을 수행

이러한 방식으로 답을 구한다면, 정렬을 하는데 O(n log n)이 걸리고, O(log3 n)만에 k번째 수를 찾을 수 있습니다.
따라서 전체 시간 복잡도는 O(n log n + m log3 n)입니다.

전체 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#include <bits/stdc++.h>
using namespace std;

const int S = 1e6;
int n, m;
int arr[S];
int sorted[S];
vector<int> data[S*2];

void update(int bucket, int node, int start, int end, int x){
	if(node<start || end<node) return;
	data[bucket].push_back(x);
	if(start != end){
		update(bucket*2, node, start, (start+end)/2, x);
		update(bucket*2+1, node, (start+end)/2+1, end, x);
	}
}

int get(int bucket, int start, int end, int left, int right, int x){
	if(left>end || right<start) return 0;
	if(left<=start && end<=right) return upper_bound(data[bucket].begin(), data[bucket].end(), x) - data[bucket].begin();
	return get(bucket*2, start, (start+end)/2, left, right, x) + get(bucket*2+1, (start+end)/2+1, end, left, right, x);
}

int main(){
	int n, m; cin >> n >> m;
	for(int i=1; i<=n; i++){
		cin >> arr[i];
		update(1, i, 1, n, arr[i]);
	}

	for(int i=0; i<S*2; i++) sort(data[i].begin(), data[i].end());

	while(m--){

		int a, b, c; cin >> a >> b >> c;
		int l = -1e9, r = 1e9;
		int mid = (l+r)/2;
		while(l<=r){
			mid = (l+r)/2;
			if(get(1, 1, n, a, b, mid) < c) l = mid+1;
			else r = mid-1;
		}
		cout << l << "\n";
	}
}