백준13537 수열과 쿼리1

문제 링크

  • http://icpc.me/13537

사용 알고리즘

  • Segment Tree
  • Merge Sort Tree (online query)

시간복잡도

  • online query
    • O(n log n + m log3 n)
  • offline query
    • O((n+m) log (n+m))

풀이

먼저 online query를 이용한 방식을 설명하겠습니다.

이 문제와 같이 특정 구간에서 k번째 수 ~~~ 혹은 k보다 큰/작은 수 ~~~ 형태의 문제는 Merge Sort Tree라는 것을 이용하여 해결하는 경우가 많습니다.
이 문제를 풀면서 나오는 Segment Tree가 Merge Sort의 진행 과정과 유사하기 때문입니다.

입력된 숫자들을 세그먼트 트리를 이용해 관리를 합니다. 트리의 각 노드는 대응하는 구간의 원소들을 정렬한 배열을 키값으로 갖습니다.


이런 식으로 merge sort tree 형태를 갖게 됩니다.

어떤 구간에서 x 보다 큰 수의 개수는 다음과 같이 재귀적으로 구할 수 있습니다.

  • 원하는 구간과 현재 탐색 중인 구간이 전혀 교차하지 않는다면, 0개
  • 원하는 구간이 현재 탐색 중인 구간을 완전히 포함한다면, 현재 탐색 중인 구간에서 이진 탐색
  • 일부만 교차한다면, 2개의 자식 노드에 대해 재귀적으로 위에 있는 두 연산을 수행

이러한 방식으로 답을 구한다면, Merge Sort Tree를 구축하는데 O(n log n)이 걸리고, O(log3 n)만에 k보다 큰 수를 찾을 수 있습니다.
따라서 전체 시간 복잡도는 O(n log n + m log3 n)입니다.


이번에는 tonyjjw님께서 설명해주신 offline query를 이용한 방식입니다.

먼저, 쿼리를 두 가지 종류로 구분합니다.

  1. 값을 삽입하는 쿼리
  2. 해를 구하는 쿼리

쿼리를 모두 입력받은 후, 쿼리를 다음 기준으로 정렬합니다.

  1. 삽입해야 하는 값 혹은 해를 구해야 하는 값이 큰 순서로
  2. 값이 같다면 해를 구하는 쿼리가 먼저 오도록

해를 구할 때는 k의 값보다 작거나 같은 값들은 전혀 신경을 쓸 필요가 없습니다.
이 성질을 이용해서 정렬을 해, 아직 필요 없는 값은 전혀 신경쓰지 않게 할 수 있습니다.

시간 복잡도는 쿼리를 정렬하는 데 O((n+m) log (n+m)), 쿼리를 수행하는데 O(n log n + m log n)이 소요되어 총 시간 복잡도는 O((n+m) log (n+m))입니다.

전체 코드

online query
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#include <bits/stdc++.h>
using namespace std;

vector<int> tree[500010];
vector<int> arr;

void update(int bucket, int node, int start, int end, int x){
	if(node<start || end<node) return;
	tree[bucket].push_back(x);
	if(start != end){
		update(bucket*2, node, start, (start+end)/2, x);
		update(bucket*2+1, node, (start+end)/2+1, end, x);
	}
}

int get(int node, int start, int end, int left, int right, int x){
	if(left>end || right<start) return 0;
	if(left<=start && end<=right) return tree[node].end() - upper_bound(tree[node].begin(), tree[node].end(), x);
	int mid = (start+end)>>1;
	return get(node*2, start, mid, left, right, x) + get(node*2+1, mid+1, end, left, right, x);
}

int main(){
	ios_base::sync_with_stdio(0); cin.tie(0);
	int n, m; cin >> n;
	arr.resize(n+10);
	for(int i=1; i<=n; i++){
		cin >> arr[i];
		update(1, i, 1, n, arr[i]);
	}

	for(int i=0; i<500010; i++) sort(tree[i].begin(), tree[i].end());

	cin >> m;
	while(m--){
		int a, b, c; cin >> a >> b >> c;
		cout << get(1, 1, n, a, b, c) << "\n";
	}
}
offline query
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#include <stdio.h>
#include <algorithm>
#include <vector>

using namespace std;

const int TN = 1 << 17;
const int MN = 100000 + 10;

int N, M;
int A[MN];
int T[2 * TN];
int ans[MN];

struct Query {
	// type = 1: insert
	// type = 2: get answer
	int type, idx, i, j, k;
};

vector<Query> queries;

void add(int n, int v) {
	for (n += TN; n > 0; n /= 2) {
		T[n] += v;
	}
}

int get_sum(int l, int r) {
	int res = 0;

	for (l += TN, r += TN; l <= r; l /= 2, r /= 2) {
		if (l % 2 == 1) {
			res += T[l++];
		}
		if (r % 2 == 0) {
			res += T[r--];
		}
	}

	return res;
}

int main() {
	scanf("%d", &N);

	for (int i = 0; i < N; i++) {
		scanf("%d", &A[i]);
		queries.push_back({ 1, i, 0, 0, A[i] });
	}

	scanf("%d", &M);

	for (int m = 0; m < M; m++) {
		int i, j, k;
		scanf("%d%d%d", &i, &j, &k); i--, j--;
		queries.push_back({ 2, m, i, j, k });
	}

	sort(queries.begin(), queries.end(), [&](Query a, Query b){
		if (a.k != b.k)
			return a.k > b.k;
		return a.type > b.type;
	});

	for (auto q : queries) {
		if (q.type == 1) {
			add(q.idx, 1);
		}
		if (q.type == 2) {
			ans[q.idx] = get_sum(q.i, q.j);
		}
	}

	for (int m = 0; m < M; m++) {
		printf("%d\n", ans[m]);
	}

	return 0;
}